diff --git a/README.md b/README.md index 7e6876da7..c9fc3cef2 100644 --- a/README.md +++ b/README.md @@ -91,9 +91,9 @@ project. Contributions are encouraged! not yet implemented. Contributions are welcome! - Of the 21859 tests in the QT3 test suite (vendored into `vendor/xpath-tests`) - that match the features we support (so excluding XQuery tests), we support - over have 20130 at the time of writing. The failures are mostly due to - missing library implementation. + that match the features we support (so excluding XQuery tests), we pass 20221 + at the time of writing. The remaining gaps are mostly missing library + functions and some parsing/formatting behavior. - XMLSchema support. While the basic `xs:*` data types as defined by XML Schema are implemented, deep XML Schema integration does not exist. @@ -146,8 +146,8 @@ The Xee project is composed of many crates. Here is a quick overview: - [`xee-testrunner`](xee-testrunner) - a testrunner that can run the QT3 conformance suite of XPath tests (in `vendor/xpath-tests`). It has also - been generalized towards supporting running XSLT conformance tests, but that - implementation is not complete yet. + been generalized to run the XSLT conformance tests (in `vendor/xslt-tests`), + though coverage is still partial and many tests are filtered. - [`xee-xpath-lexer`](xee-xpath-lexer) - A lexer for XPath expressions. diff --git a/conformance/README.md b/conformance/README.md index 6b6bf1d5f..ab3d02c76 100644 --- a/conformance/README.md +++ b/conformance/README.md @@ -30,7 +30,8 @@ there are gaps all over the place. 
`xsl:template`, `xsl:value-of`, `xsl:variable`, `xsl:if` `xsl:choose`, `xsl:when`, `xsl:otherwise`, `xsl:for-each`, `xsl:copy`, `xsl:copy-of`, -`xsl:sequence`, `xsl:apply-templates`, `xsl:text`, `xsl:attribute`, +`xsl:sequence`, `xsl:apply-templates`, `xsl:apply-imports`, `xsl:next-match`, +`xsl:call-template`, `xsl:try`, `xsl:catch`, `xsl:text`, `xsl:attribute`, `xsl:namespace`, `xsl:comment`, `xsl:processing-instruction` all have their core behavior implemented. @@ -38,7 +39,6 @@ See [xslt.md](xslt.md) for details. ### Tests -One big task is to support XSLT tests in the test runner - the test runner has -been prepared for this by making it generic and the test suite has been -imported into `vendor`, but XSLT support in the test runner is not yet -implemented. +The XSLT test suite can now be run with the test runner, but coverage is still +partial and many tests are filtered. The test runner continues to improve as +more XSLT functionality lands. diff --git a/conformance/xslt.md b/conformance/xslt.md index aa9433ee8..c4cf4104a 100644 --- a/conformance/xslt.md +++ b/conformance/xslt.md @@ -20,33 +20,45 @@ We have regexml now, so should be able to implement. ## xsl:apply-imports -TODO: import subsystem +Basic support with import precedence and non-tunnel `xsl:with-param`. + +Not yet: + +- Packages/expose/override +- Tunnel params ## xsl:apply-templates -Not yet: +Basic support: + +- Named/unnamed/current modes +- Non-tunnel `xsl:with-param` +- Built-in templates via `xsl:mode` on-no-match -- Mode support +Not yet: - Variables in patterns +- Certain axes -- Rooted patterns +## xsl:assert -- Certain axes +Basic support: -- Fallback templates +- Evaluates `test`; on failure raises `XTMM9001` or the supplied `error-code`. +- Uses `select` or the sequence constructor for the error message. -## xsl:assert +Not yet: -TODO +- Assertion disable/enable toggles outside `use-when` and the + `enable_assertions` dependency. ## xsl:attribute Cannot add after normal child. 
-Not yet: +Supports `type` (recorded in the type table; no schema validation). -- type +Not yet: - validation @@ -60,11 +72,11 @@ TODO: xsl:iterate ## xsl:call-template -TODO: function subsystem +Basic support for named templates and non-tunnel `xsl:with-param`. ## xsl:catch -TODO +Supported as part of `xsl:try` (no `xsl:fallback`). ## xsl:character-map @@ -84,33 +96,38 @@ TODO ## xsl:copy +Supports `type` (recorded in the type table; no schema validation). + Not yet: -- copy-namespaces, inherit-namespaces, use-attribute-set, type, validation +- copy-namespaces, inherit-namespaces, use-attribute-sets, validation ## xsl:copy-of +Supports `type` (recorded in the type table; no schema validation). + Not yet: -- copy-accumulators, copy-namespaces, type, validation +- copy-accumulators, copy-namespaces, validation ## xsl:decimal-format -TOD: awaiting xee-format +TODO: awaiting xee-format ## xsl:document -TODO: nodes +Basic document node construction with sequence constructor content. +Supports `type` (recorded in the type table; no schema validation). ## xsl:element +Supports `type` (recorded in the type table; no schema validation). + Not yet: - inherit-namespaces -- use-attribute sets - -- type +- use-attribute-sets - validation @@ -128,7 +145,7 @@ TODO ## xsl:for-each -Todo: +TODO: - xsl:sort support @@ -142,7 +159,12 @@ TODO ## xsl:function -TODO: function subsystem +Basic support for user-defined functions. + +Not yet: + +- Visibility/overriding/caching/streamability +- Default values for function parameters ## xsl:global-context-item @@ -154,7 +176,11 @@ Done ## xsl:import -TODO: import subsystem +Basic file-based import resolution with import precedence. + +Not yet: + +- Packages/expose/override ## xsl:import-schema @@ -162,7 +188,7 @@ TODO: schema support ## xsl:include -TODO: imoprt subsystem +Basic file-based include resolution. ## xsl:iterate @@ -206,7 +232,11 @@ TODO ## xsl:mode -TODO +Supports `on-no-match` for built-in template behavior. 
+ +Not yet: + +- Streamability and other mode attributes ## xsl:namespace @@ -224,7 +254,7 @@ TODO: xsl:iterate ## xsl:next-match -TODO: template rule subsystem, import system +Basic support with import precedence and non-tunnel `xsl:with-param`. ## xsl:non-matching-substring @@ -252,10 +282,6 @@ Done ## xsl:output -TODO - -## xsl:output - TODO: output method subsystem ## xsl:output-character @@ -272,7 +298,11 @@ TODO: import subsystem ## xsl:param -TODO: function subsystem +Supports global and template parameters (non-tunnel). + +Not yet: + +- Static params and visibility ## xsl:perform-sort @@ -308,19 +338,17 @@ TODO ## xsl:stylesheet -Not yet: all of the attibutes +Not yet: all of the attributes ## xsl:template Including priority. -Not yet: +Named templates and mode selection are supported. -- match: variable support, rooted paths, certain axes - -- name +Not yet: -- mode +- match: variable support, certain axes - as @@ -330,7 +358,7 @@ Not yet: Not yet: -- depecrated disable-output-escaping +- deprecated disable-output-escaping ## xsl:transform @@ -338,7 +366,11 @@ See xsl:stylesheet ## xsl:try -TODO +Basic support for `xsl:try`/`xsl:catch`, including `rollback-output`. + +Not yet: + +- `xsl:fallback` ## xsl:use-package @@ -352,13 +384,12 @@ Done except: ## xsl:variable -Not yet: - -- compile-time variables used as global variables +Basic support for global and local variables (non-static). -- global variables +Not yet: -- attributes: as, visbiility +- static variables (compile-time) +- attributes: as, visibility ## xsl:when @@ -370,4 +401,8 @@ TODO ## xsl:with-param -Todo: function subsystem +Supported for apply-templates/apply-imports/next-match/call-template. + +Not yet: + +- Tunnel params diff --git a/hacking.md b/hacking.md index 35006a156..f405b9168 100644 --- a/hacking.md +++ b/hacking.md @@ -10,7 +10,7 @@ conformance test suite of more than 20,000 tests. 
## XPath -### Xpath functions +### XPath functions The [XPath and XQuery Functions and Operators](https://www.w3.org/TR/xpath-functions-31/) specification describes @@ -57,7 +57,7 @@ from the interpreter. Any function can take special optional arguments `context` and `interpreter` as the first two arguments. In this case the system automatically injects these objects into your function. -One you've created a function it needs to be registered at the bottom in +Once you've created a function it needs to be registered at the bottom in `static_function_descriptions` using the `wrap_xpath_fn!` macro. In case you're creating a new library module, you need to hook up your new @@ -83,16 +83,16 @@ tests we already *know* should pass as they did in the past. This gives a result in the end that reads like this: ``` -Total: 31812 Supported: 21859 Passed: 19987 Failed: 0 Error: 0 WrongE: 0 -Filtered: 1872 Unsupported: 9953 +Total: 31812 Supported: 21859 Passed: 20221 Failed: 0 Error: 0 WrongE: 0 +Filtered: 1638 Unsupported: 9953 ``` `Total` is the total amount of tests in the suite. This includes features that we don't support, most prominently XQuery, so `Supported` is the total amount -of tests that are relevant to use. `Passed` indicates how many of the tests -that behave as expected, `Failed`, the tests that failed (wrong answers), -`Error` the tests that had an unexpected error, `WrongE` those tests that -expect an error but the wrong error is returned. +of tests that are relevant to use. `Passed` indicates how many tests behave as +expected, `Failed` the tests that failed (wrong answers), `Error` the tests +that had an unexpected error, and `WrongE` those tests that expect an error +but the wrong error is returned. `Filtered` is those tests we want to support but do not work yet - we know they fail in advance so they're filtered out by `check`. @@ -117,17 +117,17 @@ are no regressions - `Failed`, `Error` and `WrongE` should remain at 0. 
### Zooming in on tests -You can run `all` against a whole test xml file. To rerun just the `node-name` -tests. +You can run `all` against a whole test XML file. To rerun just the `node-name` +tests, use: ``` cargo run --release -- -v all ../vendor/xpath-tests/fn/node-name.xml ``` Thanks to the `-v` option you can see the test names and you can also see more -information about test passing and failure. +information about pass/fail. -You can also filter tests further by using (part of) its name in the XML file +You can also filter tests further by using (part of) their name in the XML file ``` cargo run --release -- -v all ../vendor/xpath-tests/fn/node-name.xml fn-node-name-1 @@ -149,8 +149,8 @@ and then rerunning: cargo run --release -- initialize ../vendor/xpath-tests/ ``` -This regenerates the `filters` file from scratch. This means that are newly -failing are added to it, *decreasing* the total amount of tests that pass +This regenerates the `filters` file from scratch. This means tests that are +newly failing are added to it, *decreasing* the total amount of tests that pass successfully. When you do this it makes sense to do a diff with the previous version of `filters` to see whether you've made any mistakes and caused too many tests to fail. @@ -177,15 +177,15 @@ cargo run --release -- check ../vendor/xslt-tests/ You can improve tests just as described for XPath. -Note that some features of the XSLT test runner such as passing in parameters -are not yet working correctly; so if a test fails you may also want to suspect -the test runner setup, not just the implementation. Like for the XPath -conformance tests we intend to improve support for the testrunner -incrementally. +The runner reads `` and `initial-template` from the test metadata and +honors the `enable_assertions` dependency for `xsl:assert`. Coverage is still +partial and many tests are filtered, so failures can still be due to missing +runner features or missing implementation. 
Like for the XPath conformance +tests we intend to improve support for the testrunner incrementally. ### Adding XSLT functionality -The XSLT AST is pretty complete, and underlying IR and bytecode interpreter +The XSLT AST is pretty complete, and the underlying IR and bytecode interpreter supports a lot of XSLT functionality already. Much of the effort of adding XSLT functionality is focused on translating the XSLT AST into the IR format. This -is done by `xee-xslt-compiler/src/test_xslt.rs`. +is done by `xee-xslt-compiler/tests/test_xslt.rs`. diff --git a/vendor/xslt-tests/filters b/vendor/xslt-tests/filters index f5218eac4..473a88642 100644 --- a/vendor/xslt-tests/filters +++ b/vendor/xslt-tests/filters @@ -571,15 +571,8 @@ as-3605 as-3701 = aspiring = assert -assert-001 -assert-002 -assert-003 -assert-004 assert-005 -assert-006 assert-007 -assert-008 -assert-009 assert-010 = attribute attribute-0001 @@ -5609,38 +5602,6 @@ package-version-910 package-version-911 package-version-912a package-version-912b -= param -param-0101 -param-0102 -param-0103 -param-0104 -param-0105 -param-0106 -param-0107 -param-0108 -param-0109 -param-0110 -param-0111 -param-0112 -param-0113 -param-0114 -param-0115 -param-0116 -param-0117 -param-0118 -param-0119 -param-0120 -param-0201 -param-0301 -param-0401 -param-0402 -param-0403 -param-0501 -param-0601 -param-0602 -param-0701 -param-0702 -param-0703 = path path-010 = position diff --git a/xee-interpreter/README.md b/xee-interpreter/README.md index 8061364ea..1727c7139 100644 --- a/xee-interpreter/README.md +++ b/xee-interpreter/README.md @@ -8,8 +8,8 @@ bytecode interpreter, an implementation of XPath data types, and the XPath standard library. This is used by [`xee-xpath`](https://docs.rs/xee-xpath/latest/xee_xpath/) to -implement XPath and can also serve as the engine to execute XSLT code (work in -progress). 
+implement XPath and can also serve as the engine to execute XSLT code (partial +support; template dispatch, parameters, and try/catch work, but much remains). This is a low-level crate of the [Xee project](https://github.com/Paligo/xee). For the API entry point see diff --git a/xee-interpreter/src/context/dynamic_context.rs b/xee-interpreter/src/context/dynamic_context.rs index fc1fe073c..d6c1249ac 100644 --- a/xee-interpreter/src/context/dynamic_context.rs +++ b/xee-interpreter/src/context/dynamic_context.rs @@ -6,7 +6,7 @@ use crate::function::{self, Function}; use crate::{error::Error, interpreter::Program}; use crate::{interpreter, sequence}; -use super::{DocumentsRef, StaticContext}; +use super::{DocumentsRef, StaticContext, TypeTableRef}; /// A map of variables /// @@ -27,6 +27,7 @@ pub struct DynamicContext<'a> { // multiple spots. We use RefCell to manage that during runtime so we don't // need to make the whole thing immutable. documents: DocumentsRef, + type_table: TypeTableRef, variables: Variables, // TODO: we want to be able to control the creation of this outside, // as it needs to be the same for all evalutions of XSLT I believe @@ -49,6 +50,7 @@ impl<'a> DynamicContext<'a> { program: &'a Program, context_item: Option, documents: DocumentsRef, + type_table: TypeTableRef, variables: Variables, current_datetime: chrono::DateTime, default_collection: Option, @@ -61,6 +63,7 @@ impl<'a> DynamicContext<'a> { program, context_item, documents, + type_table, variables, current_datetime, default_collection, @@ -86,6 +89,10 @@ impl<'a> DynamicContext<'a> { self.documents.clone() } + pub fn type_table(&self) -> TypeTableRef { + self.type_table.clone() + } + /// The variables in this context. 
 pub fn variables(&self) -> &Variables { &self.variables diff --git a/xee-interpreter/src/context/dynamic_context_builder.rs b/xee-interpreter/src/context/dynamic_context_builder.rs index 4f33a8323..641f88412 100644 --- a/xee-interpreter/src/context/dynamic_context_builder.rs +++ b/xee-interpreter/src/context/dynamic_context_builder.rs @@ -19,6 +19,7 @@ pub struct DynamicContextBuilder<'a> { program: &'a interpreter::Program, context_item: Option, documents: DocumentsRef, + type_table: TypeTableRef, variables: Variables, current_datetime: chrono::DateTime, default_collection: Option, @@ -51,6 +52,35 @@ impl DocumentsRef { } } +#[derive(Debug, Clone)] +pub struct TypeTableRef(Rc<RefCell<xml::TypeTable>>); + +impl Deref for TypeTableRef { + type Target = RefCell<xml::TypeTable>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From<xml::TypeTable> for TypeTableRef { + fn from(table: xml::TypeTable) -> Self { + Self(Rc::new(RefCell::new(table))) + } +} + +impl TypeTableRef { + pub fn new() -> Self { + Self(Rc::new(RefCell::new(xml::TypeTable::new()))) + } +} + +impl Default for TypeTableRef { + fn default() -> Self { + Self::new() + } +} + impl Default for DocumentsRef { fn default() -> Self { Self::new() @@ -64,6 +94,7 @@ impl<'a> DynamicContextBuilder<'a> { program, context_item: None, documents: DocumentsRef::new(), + type_table: TypeTableRef::new(), variables: Variables::new(), current_datetime: chrono::offset::Local::now().into(), default_collection: None, @@ -96,6 +127,11 @@ impl<'a> DynamicContextBuilder<'a> { self } + pub fn type_table(&mut self, type_table: impl Into<TypeTableRef>) -> &mut Self { + self.type_table = type_table.into(); + self + } + /// Set the variables of the [`DynamicContext`]. /// /// Without this, the [`DynamicContext`] will have no variables. 
@@ -173,6 +209,7 @@ impl<'a> DynamicContextBuilder<'a> { self.program, self.context_item.clone(), self.documents.clone(), + self.type_table.clone(), self.variables.clone(), self.current_datetime, self.default_collection.clone(), diff --git a/xee-interpreter/src/context/mod.rs b/xee-interpreter/src/context/mod.rs index e1028fddd..5ad9e43a9 100644 --- a/xee-interpreter/src/context/mod.rs +++ b/xee-interpreter/src/context/mod.rs @@ -6,6 +6,6 @@ mod static_context; mod static_context_builder; pub use dynamic_context::{DynamicContext, Variables}; -pub use dynamic_context_builder::{DocumentsRef, DynamicContextBuilder}; +pub use dynamic_context_builder::{DocumentsRef, DynamicContextBuilder, TypeTableRef}; pub use static_context::StaticContext; pub use static_context_builder::StaticContextBuilder; diff --git a/xee-interpreter/src/context/static_context.rs b/xee-interpreter/src/context/static_context.rs index d17125672..69d7883f6 100644 --- a/xee-interpreter/src/context/static_context.rs +++ b/xee-interpreter/src/context/static_context.rs @@ -31,11 +31,12 @@ pub struct StaticContext { // TODO: try to make collations static collations: RefCell, static_base_uri: Option, + assertions_enabled: bool, } impl Default for StaticContext { fn default() -> Self { - Self::new(Namespaces::default(), VariableNames::default(), None) + Self::new(Namespaces::default(), VariableNames::default(), None, true) } } @@ -46,6 +47,7 @@ impl From for StaticContext { functions: &STATIC_FUNCTIONS, collations: RefCell::new(Collations::new()), static_base_uri: None, + assertions_enabled: true, } } } @@ -55,17 +57,19 @@ impl StaticContext { namespaces: Namespaces, variable_names: VariableNames, static_base_uri: Option, + assertions_enabled: bool, ) -> Self { Self { parser_context: XPathParserContext::new(namespaces, variable_names), functions: &STATIC_FUNCTIONS, collations: RefCell::new(Collations::new()), static_base_uri, + assertions_enabled, } } pub fn from_namespaces(namespaces: Namespaces) -> Self { 
- Self::new(namespaces, VariableNames::default(), None) + Self::new(namespaces, VariableNames::default(), None, true) } pub fn namespaces(&self) -> &Namespaces { @@ -100,6 +104,10 @@ impl StaticContext { self.static_base_uri.as_deref() } + pub fn assertions_enabled(&self) -> bool { + self.assertions_enabled + } + pub(crate) fn collation(&self, uri: &IriReferenceStr) -> error::Result> { self.collations .borrow_mut() diff --git a/xee-interpreter/src/context/static_context_builder.rs b/xee-interpreter/src/context/static_context_builder.rs index aca5a0643..fab79d2e8 100644 --- a/xee-interpreter/src/context/static_context_builder.rs +++ b/xee-interpreter/src/context/static_context_builder.rs @@ -1,17 +1,31 @@ -use ahash::HashMap; +use ahash::{HashMap, HashMapExt}; use iri_string::types::IriAbsoluteString; use xee_name::Namespaces; use xot::xmlname::OwnedName; use crate::context; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct StaticContextBuilder<'a> { variable_names: Vec, namespaces: HashMap<&'a str, &'a str>, default_element_namespace: &'a str, default_function_namespace: &'a str, static_base_uri: Option, + assertions_enabled: bool, +} + +impl Default for StaticContextBuilder<'_> { + fn default() -> Self { + Self { + variable_names: Vec::new(), + namespaces: HashMap::new(), + default_element_namespace: "", + default_function_namespace: "", + static_base_uri: None, + assertions_enabled: true, + } + } } impl<'a> StaticContextBuilder<'a> { @@ -75,6 +89,11 @@ impl<'a> StaticContextBuilder<'a> { self } + pub fn assertions_enabled(&mut self, enabled: bool) -> &mut Self { + self.assertions_enabled = enabled; + self + } + /// Build the static context. 
/// /// This will always include the default known namespaces for @@ -96,7 +115,12 @@ impl<'a> StaticContextBuilder<'a> { default_function_namespace.to_string(), ); let variable_names = self.variable_names.clone().into_iter().collect(); - context::StaticContext::new(namespaces, variable_names, self.static_base_uri.clone()) + context::StaticContext::new( + namespaces, + variable_names, + self.static_base_uri.clone(), + self.assertions_enabled, + ) } } diff --git a/xee-interpreter/src/declaration/decl.rs b/xee-interpreter/src/declaration/decl.rs index d9b6a1dea..2162db645 100644 --- a/xee-interpreter/src/declaration/decl.rs +++ b/xee-interpreter/src/declaration/decl.rs @@ -1,14 +1,81 @@ -use crate::{function, pattern::ModeLookup}; +use ahash::{HashMap, HashMapExt}; + +use crate::{ + function, + pattern::{ModeId, RuleEntry}, +}; +use xee_xpath_ast::Pattern; + +#[derive(Debug, Clone)] +pub struct GlobalParam { + pub name: xot::xmlname::OwnedName, + pub required: bool, + pub overrideable: bool, + pub default: Option, +} + +#[derive(Debug, Clone)] +pub struct TemplateParam { + pub name: xot::xmlname::OwnedName, + pub required: bool, + pub default: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CatchError { + Any, + Namespace(String), + Local(String), + Name(xot::xmlname::OwnedName), +} + +#[derive(Debug, Clone)] +pub struct CatchClause { + pub errors: Vec, +} + +#[derive(Debug, Clone)] +pub struct TryCatch { + pub rollback_output: bool, + pub catches: Vec, +} #[derive(Debug)] pub struct Declarations { - pub mode_lookup: ModeLookup, + pub mode_rules: HashMap, RuleEntry)>>, + pub mode_configs: HashMap, + pub named_templates: HashMap, + pub user_functions: Vec, + pub global_params: Vec, + pub template_params: HashMap>, + pub try_catches: Vec, } impl Declarations { pub(crate) fn new() -> Self { Self { - mode_lookup: ModeLookup::new(), + mode_rules: HashMap::new(), + mode_configs: HashMap::new(), + named_templates: HashMap::new(), + user_functions: Vec::new(), 
+ global_params: Vec::new(), + template_params: HashMap::new(), + try_catches: Vec::new(), } } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OnNoMatch { + DeepCopy, + ShallowCopy, + DeepSkip, + ShallowSkip, + TextOnlyCopy, + Fail, +} + +#[derive(Debug, Clone, Copy)] +pub struct ModeConfig { + pub on_no_match: Option, +} diff --git a/xee-interpreter/src/declaration/mod.rs b/xee-interpreter/src/declaration/mod.rs index 47d699072..85ae76fde 100644 --- a/xee-interpreter/src/declaration/mod.rs +++ b/xee-interpreter/src/declaration/mod.rs @@ -4,4 +4,7 @@ mod decl; mod globalvar; -pub use decl::Declarations; +pub use decl::{ + CatchClause, CatchError, Declarations, GlobalParam, ModeConfig, OnNoMatch, TemplateParam, + TryCatch, +}; diff --git a/xee-interpreter/src/error.rs b/xee-interpreter/src/error.rs index 9a5630977..f3f42abf3 100644 --- a/xee-interpreter/src/error.rs +++ b/xee-interpreter/src/error.rs @@ -540,6 +540,34 @@ pub enum Error { /// cannot handle such characters. FOXT0006, + /// Required parameter has a default value. + /// + /// It is a static error if xsl:param specifies required="yes" and also + /// specifies a select attribute or sequence constructor. + XTSE0010, + /// Invalid attribute value. + /// + /// It is a static error if a stylesheet attribute has a value that is not + /// allowed by the specification. + XTSE0020, + /// Attribute not permitted on an XSLT element. + /// + /// It is a static error if an attribute is not permitted for the + /// containing XSLT element. + XTSE0090, + /// Required stylesheet parameter not supplied. + /// + /// It is a dynamic error if a required stylesheet parameter is not supplied. + XTDE0050, + /// Required template parameter not supplied. + /// + /// It is a dynamic error if a required template parameter is not supplied. + XTDE0060, + /// Required parameter not supplied (alternative code). + /// + /// Some processors use XTDE0700 for missing required parameters. 
+ XTDE0700, + /// Duplicate global variable name. /// /// It is a static error if a package contains more than one non-hidden diff --git a/xee-interpreter/src/function/map.rs b/xee-interpreter/src/function/map.rs index ca4ed51ed..56f3c58d6 100644 --- a/xee-interpreter/src/function/map.rs +++ b/xee-interpreter/src/function/map.rs @@ -5,6 +5,7 @@ use xee_schema_type::Xs; use xee_xpath_ast::ast; use xot::Xot; +use crate::xml::TypeTable; use crate::{atomic, context, error, sequence, string}; /// An XPath Map (a collection of key-value pairs). @@ -145,11 +146,18 @@ impl Map { atomic_type: Xs, static_context: &context::StaticContext, xot: &Xot, + type_table: &TypeTable, ) -> error::Result> { match self { - Map::Empty(map) => map.get_as_type(key, occurrence, atomic_type, static_context, xot), - Map::One(map) => map.get_as_type(key, occurrence, atomic_type, static_context, xot), - Map::Many(map) => map.get_as_type(key, occurrence, atomic_type, static_context, xot), + Map::Empty(map) => { + map.get_as_type(key, occurrence, atomic_type, static_context, xot, type_table) + } + Map::One(map) => { + map.get_as_type(key, occurrence, atomic_type, static_context, xot, type_table) + } + Map::Many(map) => { + map.get_as_type(key, occurrence, atomic_type, static_context, xot, type_table) + } } } @@ -286,6 +294,7 @@ pub(crate) trait Mappable { atomic_type: Xs, static_context: &context::StaticContext, xot: &Xot, + type_table: &TypeTable, ) -> error::Result> { let value = self.get(key); let value = match value { @@ -302,6 +311,7 @@ pub(crate) trait Mappable { &sequence_type, static_context, xot, + type_table, // typed function tests can't be invoked &|_function| unreachable!(), )?, diff --git a/xee-interpreter/src/interpreter/instruction.rs b/xee-interpreter/src/interpreter/instruction.rs index f7b4ff14e..c0d8a1387 100644 --- a/xee-interpreter/src/interpreter/instruction.rs +++ b/xee-interpreter/src/interpreter/instruction.rs @@ -75,11 +75,22 @@ pub enum Instruction { XmlComment, 
XmlProcessingInstruction, XmlAppend, + XmlSetType(u16), CopyShallow, CopyDeep, ApplyTemplates(u16), + ApplyTemplatesWithParams(u16, u16), + ApplyTemplatesCurrent, + ApplyTemplatesCurrentWithParams(u16), + ApplyImports, + ApplyImportsWithParams(u16), + ApplyNextMatch, + ApplyNextMatchWithParams(u16), + CallTemplate(u16), + CallTemplateWithParams(u16, u16), PrintTop, PrintStack, + TryCatch(u16), } #[derive(Debug, ToPrimitive, FromPrimitive, Clone, PartialEq, Eq, Hash)] @@ -154,11 +165,22 @@ pub(crate) enum EncodedInstruction { XmlComment, XmlProcessingInstruction, XmlAppend, + XmlSetType, ApplyTemplates, + ApplyTemplatesWithParams, + ApplyTemplatesCurrent, + ApplyTemplatesCurrentWithParams, + ApplyImports, + ApplyImportsWithParams, + ApplyNextMatch, + ApplyNextMatchWithParams, + CallTemplate, + CallTemplateWithParams, CopyShallow, CopyDeep, PrintTop, PrintStack, + TryCatch, } // decode a single instruction from the slice @@ -283,14 +305,57 @@ pub(crate) fn decode_instruction(bytes: &[u8]) -> (Instruction, usize) { EncodedInstruction::XmlComment => (Instruction::XmlComment, 1), EncodedInstruction::XmlProcessingInstruction => (Instruction::XmlProcessingInstruction, 1), EncodedInstruction::XmlAppend => (Instruction::XmlAppend, 1), + EncodedInstruction::XmlSetType => { + let xs = u16::from_le_bytes([bytes[1], bytes[2]]); + (Instruction::XmlSetType(xs), 3) + } EncodedInstruction::CopyShallow => (Instruction::CopyShallow, 1), EncodedInstruction::CopyDeep => (Instruction::CopyDeep, 1), EncodedInstruction::ApplyTemplates => { let mode_id = u16::from_le_bytes([bytes[1], bytes[2]]); (Instruction::ApplyTemplates(mode_id), 3) } + EncodedInstruction::ApplyTemplatesWithParams => { + let mode_id = u16::from_le_bytes([bytes[1], bytes[2]]); + let param_count = u16::from_le_bytes([bytes[3], bytes[4]]); + (Instruction::ApplyTemplatesWithParams(mode_id, param_count), 5) + } + EncodedInstruction::ApplyTemplatesCurrent => (Instruction::ApplyTemplatesCurrent, 1), + 
EncodedInstruction::ApplyTemplatesCurrentWithParams => { + let param_count = u16::from_le_bytes([bytes[1], bytes[2]]); + ( + Instruction::ApplyTemplatesCurrentWithParams(param_count), + 3, + ) + } + EncodedInstruction::ApplyImports => (Instruction::ApplyImports, 1), + EncodedInstruction::ApplyImportsWithParams => { + let param_count = u16::from_le_bytes([bytes[1], bytes[2]]); + (Instruction::ApplyImportsWithParams(param_count), 3) + } + EncodedInstruction::ApplyNextMatch => (Instruction::ApplyNextMatch, 1), + EncodedInstruction::ApplyNextMatchWithParams => { + let param_count = u16::from_le_bytes([bytes[1], bytes[2]]); + (Instruction::ApplyNextMatchWithParams(param_count), 3) + } + EncodedInstruction::CallTemplate => { + let template_id = u16::from_le_bytes([bytes[1], bytes[2]]); + (Instruction::CallTemplate(template_id), 3) + } + EncodedInstruction::CallTemplateWithParams => { + let template_id = u16::from_le_bytes([bytes[1], bytes[2]]); + let param_count = u16::from_le_bytes([bytes[3], bytes[4]]); + ( + Instruction::CallTemplateWithParams(template_id, param_count), + 5, + ) + } EncodedInstruction::PrintTop => (Instruction::PrintTop, 1), EncodedInstruction::PrintStack => (Instruction::PrintStack, 1), + EncodedInstruction::TryCatch => { + let try_catch_id = u16::from_le_bytes([bytes[1], bytes[2]]); + (Instruction::TryCatch(try_catch_id), 3) + } } } @@ -435,14 +500,57 @@ pub fn encode_instruction(instruction: Instruction, bytes: &mut Vec) { .unwrap(), ), Instruction::XmlAppend => bytes.push(EncodedInstruction::XmlAppend.to_u8().unwrap()), + Instruction::XmlSetType(xs) => { + bytes.push(EncodedInstruction::XmlSetType.to_u8().unwrap()); + bytes.extend_from_slice(&xs.to_le_bytes()); + } Instruction::CopyShallow => bytes.push(EncodedInstruction::CopyShallow.to_u8().unwrap()), Instruction::CopyDeep => bytes.push(EncodedInstruction::CopyDeep.to_u8().unwrap()), Instruction::ApplyTemplates(mode_id) => { bytes.push(EncodedInstruction::ApplyTemplates.to_u8().unwrap()); 
 bytes.extend_from_slice(&mode_id.to_le_bytes()); + } + Instruction::ApplyTemplatesWithParams(mode_id, param_count) => { + bytes.push(EncodedInstruction::ApplyTemplatesWithParams.to_u8().unwrap()); + bytes.extend_from_slice(&mode_id.to_le_bytes()); + bytes.extend_from_slice(&param_count.to_le_bytes()); + } + Instruction::ApplyTemplatesCurrent => { + bytes.push(EncodedInstruction::ApplyTemplatesCurrent.to_u8().unwrap()); + } + Instruction::ApplyTemplatesCurrentWithParams(param_count) => { + bytes.push(EncodedInstruction::ApplyTemplatesCurrentWithParams.to_u8().unwrap()); + bytes.extend_from_slice(&param_count.to_le_bytes()); + } + Instruction::ApplyImports => { + bytes.push(EncodedInstruction::ApplyImports.to_u8().unwrap()); + } + Instruction::ApplyImportsWithParams(param_count) => { + bytes.push(EncodedInstruction::ApplyImportsWithParams.to_u8().unwrap()); + bytes.extend_from_slice(&param_count.to_le_bytes()); + } + Instruction::ApplyNextMatch => { + bytes.push(EncodedInstruction::ApplyNextMatch.to_u8().unwrap()); + } + Instruction::ApplyNextMatchWithParams(param_count) => { + bytes.push(EncodedInstruction::ApplyNextMatchWithParams.to_u8().unwrap()); + bytes.extend_from_slice(&param_count.to_le_bytes()); + } + Instruction::CallTemplate(template_id) => { + bytes.push(EncodedInstruction::CallTemplate.to_u8().unwrap()); + bytes.extend_from_slice(&template_id.to_le_bytes()); + } + Instruction::CallTemplateWithParams(template_id, param_count) => { + bytes.push(EncodedInstruction::CallTemplateWithParams.to_u8().unwrap()); + bytes.extend_from_slice(&template_id.to_le_bytes()); + bytes.extend_from_slice(&param_count.to_le_bytes()); + } Instruction::PrintTop => bytes.push(EncodedInstruction::PrintTop.to_u8().unwrap()), Instruction::PrintStack => bytes.push(EncodedInstruction::PrintStack.to_u8().unwrap()), + Instruction::TryCatch(try_catch_id) => { + bytes.push(EncodedInstruction::TryCatch.to_u8().unwrap()); + bytes.extend_from_slice(&try_catch_id.to_le_bytes()); + } } } @@ -528,8 +636,17 @@ pub 
fn instruction_size(instruction: &Instruction) -> usize { | Instruction::InstanceOf(_) | Instruction::Treat(_) | Instruction::ReturnConvert(_) - | Instruction::JumpIfFalse(_) => 3, - Instruction::ApplyTemplates(_) => 3, + | Instruction::XmlSetType(_) + | Instruction::JumpIfFalse(_) + | Instruction::TryCatch(_) => 3, + Instruction::ApplyTemplates(_) | Instruction::CallTemplate(_) => 3, + Instruction::ApplyImports => 1, + Instruction::ApplyImportsWithParams(_) => 3, + Instruction::ApplyNextMatch => 1, + Instruction::ApplyNextMatchWithParams(_) => 3, + Instruction::ApplyTemplatesWithParams(_, _) | Instruction::CallTemplateWithParams(_, _) => 5, + Instruction::ApplyTemplatesCurrent => 1, + Instruction::ApplyTemplatesCurrentWithParams(_) => 3, } } diff --git a/xee-interpreter/src/interpreter/interpret.rs b/xee-interpreter/src/interpreter/interpret.rs index 7e2a34a5a..66e2f1220 100644 --- a/xee-interpreter/src/interpreter/interpret.rs +++ b/xee-interpreter/src/interpreter/interpret.rs @@ -1,6 +1,7 @@ use std::cmp::Ordering; use std::rc::Rc; +use ahash::AHashMap; use ibig::{ibig, IBig}; use xee_name::Name; @@ -14,6 +15,7 @@ use crate::atomic::{ op_add, op_div, op_idiv, op_mod, op_multiply, op_subtract, OpEq, OpGe, OpGt, OpLe, OpLt, OpNe, }; use crate::context::DynamicContext; +use crate::declaration::OnNoMatch; use crate::function; use crate::pattern::PredicateMatcher; use crate::sequence; @@ -48,10 +50,10 @@ impl From for ContextInfo { } impl<'a> Interpreter<'a> { - pub fn new(runnable: &'a Runnable<'a>, xot: &'a mut Xot) -> Self { + pub fn new(runnable: &'a Runnable<'a>, xot: &'a mut Xot, type_table: crate::context::TypeTableRef) -> Self { Interpreter { runnable, - state: State::new(xot), + state: State::new_with_type_table(xot, type_table), } } @@ -64,6 +66,7 @@ impl<'a> Interpreter<'a> { } pub fn start(&mut self, context_info: ContextInfo, arguments: Vec) { + self.state.set_global_params(arguments.clone()); self.start_function(self.runnable.program().main_id(), 
context_info, arguments) } @@ -362,7 +365,10 @@ impl<'a> Interpreter<'a> { let step_id = self.read_u16(); let node: xot::Node = self.state.pop()?.try_into()?; let step = &(self.current_inline_function().steps[step_id as usize]); - let value = xml::resolve_step(step, node, self.state.xot()); + let value = { + let type_table = self.state.type_table.borrow(); + xml::resolve_step(step, node, self.state.xot(), &type_table) + }; self.state.push(value); } EncodedInstruction::Deduplicate => { @@ -385,13 +391,16 @@ impl<'a> Interpreter<'a> { let sequence = self.state.pop()?; let sequence_type = &(self.current_inline_function().sequence_types[sequence_type_id as usize]); - - let sequence = sequence.sequence_type_matching_function_conversion( - sequence_type, - self.runnable.static_context(), - self.state.xot(), - &|function| self.runnable.function_info(function).signature(), - )?; + let sequence = { + let type_table = self.state.type_table.borrow(); + sequence.sequence_type_matching_function_conversion( + sequence_type, + self.runnable.static_context(), + self.state.xot(), + &type_table, + &|function| self.runnable.function_info(function).signature(), + )? 
+ }; self.state.push(sequence); } EncodedInstruction::LetDone => { @@ -433,11 +442,15 @@ impl<'a> Interpreter<'a> { let sequence = self.state.pop()?; let sequence_type = &(self.current_inline_function().sequence_types[sequence_type_id as usize]); - let matches = sequence.sequence_type_matching( - sequence_type, - self.state.xot(), - &|function| self.runnable.function_info(function).signature(), - ); + let matches = { + let type_table = self.state.type_table.borrow(); + sequence.sequence_type_matching( + sequence_type, + self.state.xot(), + &type_table, + &|function| self.runnable.function_info(function).signature(), + ) + }; if matches.is_ok() { self.state.push(true); } else { @@ -449,11 +462,15 @@ impl<'a> Interpreter<'a> { let sequence = self.state.top()?; let sequence_type = &(self.current_inline_function().sequence_types[sequence_type_id as usize]); - let matches = sequence.sequence_type_matching( - sequence_type, - self.state.xot(), - &|function| self.runnable.function_info(function).signature(), - ); + let matches = { + let type_table = self.state.type_table.borrow(); + sequence.sequence_type_matching( + sequence_type, + self.state.xot(), + &type_table, + &|function| self.runnable.function_info(function).signature(), + ) + }; if matches.is_err() { Err(error::Error::XPDY0050)?; } @@ -592,6 +609,16 @@ impl<'a> Interpreter<'a> { let item = sequence::Item::Node(parent_node); self.state.push(item); } + EncodedInstruction::XmlSetType => { + let xs_id = self.read_u16(); + let value = self.state.pop()?; + let xs = Xs::from_u16(xs_id).ok_or(error::Error::XPTY0004)?; + for node in value.nodes() { + let node = node?; + self.state.set_node_type(node, xs); + } + self.state.push(value); + } EncodedInstruction::CopyShallow => { let value = &self.state.pop()?; if value.is_empty() { @@ -622,7 +649,7 @@ impl<'a> Interpreter<'a> { let copy = match &item { sequence::Item::Atomic(_) | sequence::Item::Function(_) => item.clone(), sequence::Item::Node(node) => { - let copied_node = 
self.state.xot.clone_node(*node); + let copied_node = self.state.clone_node_with_type(*node); sequence::Item::Node(copied_node) } }; @@ -634,9 +661,137 @@ impl<'a> Interpreter<'a> { let value = self.state.pop()?; let mode_id = self.read_u16(); let mode = pattern::ModeId::new(mode_id as usize); - let value = self.apply_templates_sequence(mode, value)?; + let value = self.apply_templates_sequence(mode, value, None)?; + self.state.push(value); + } + EncodedInstruction::ApplyTemplatesWithParams => { + let value = self.state.pop()?; + let mode_id = self.read_u16(); + let param_count = self.read_u16(); + let params = self.pop_with_param_map(param_count)?; + let mode = pattern::ModeId::new(mode_id as usize); + let value = self.apply_templates_sequence(mode, value, Some(&params))?; + self.state.push(value); + } + EncodedInstruction::ApplyTemplatesCurrent => { + let value = self.state.pop()?; + let mode = self + .state + .frame() + .mode + .ok_or_else(|| { + error::Error::Unsupported( + "xsl:apply-templates has no current mode".to_string(), + ) + })?; + let value = self.apply_templates_sequence(mode, value, None)?; + self.state.push(value); + } + EncodedInstruction::ApplyTemplatesCurrentWithParams => { + let value = self.state.pop()?; + let param_count = self.read_u16(); + let params = self.pop_with_param_map(param_count)?; + let mode = self + .state + .frame() + .mode + .ok_or_else(|| { + error::Error::Unsupported( + "xsl:apply-templates has no current mode".to_string(), + ) + })?; + let value = self.apply_templates_sequence(mode, value, Some(&params))?; + self.state.push(value); + } + EncodedInstruction::ApplyImports => { + let value = self.apply_imports(None)?; + self.state.push(value); + } + EncodedInstruction::ApplyImportsWithParams => { + let param_count = self.read_u16(); + let params = self.pop_with_param_map(param_count)?; + let value = self.apply_imports(Some(&params))?; + self.state.push(value); + } + EncodedInstruction::ApplyNextMatch => { + let value = 
self.apply_next_match(None)?; self.state.push(value); } + EncodedInstruction::ApplyNextMatchWithParams => { + let param_count = self.read_u16(); + let params = self.pop_with_param_map(param_count)?; + let value = self.apply_next_match(Some(&params))?; + self.state.push(value); + } + EncodedInstruction::CallTemplate => { + let template_id = self.read_u16(); + let function_id = function::InlineFunctionId::new(template_id as usize); + let value = self.call_named_template(function_id, None)?; + self.state.push(value); + } + EncodedInstruction::CallTemplateWithParams => { + let template_id = self.read_u16(); + let param_count = self.read_u16(); + let params = self.pop_with_param_map(param_count)?; + let function_id = function::InlineFunctionId::new(template_id as usize); + let value = self.call_named_template(function_id, Some(&params))?; + self.state.push(value); + } + EncodedInstruction::TryCatch => { + let try_catch_id = self.read_u16(); + let entry = self + .runnable + .program() + .declarations + .try_catches + .get(try_catch_id as usize) + .ok_or_else(|| { + error::Error::Unsupported("Try/catch entry not found".to_string()) + })?; + + let mut catch_functions = Vec::with_capacity(entry.catches.len()); + for _ in 0..entry.catches.len() { + let function = function::Function::try_from(self.state.pop()?)?; + catch_functions.push(function); + } + catch_functions.reverse(); + let try_function = function::Function::try_from(self.state.pop()?)?; + + let (item, position, size) = self.current_context_values(); + let mut base_args = vec![item, position, size]; + base_args.extend( + self.state + .global_params() + .iter() + .cloned() + .map(stack::Value::from), + ); + + let snapshot = self.state.snapshot(entry.rollback_output); + match self.call_function_with_values(&try_function, &base_args) { + Ok(sequence) => { + self.state.push(sequence); + } + Err(err) => { + self.state.restore(snapshot); + let mut handled = false; + for (index, clause) in entry.catches.iter().enumerate() { + if 
self.match_catch_clause(&err, clause) { + let result = self.call_function_with_values( + &catch_functions[index], + &base_args, + )?; + self.state.push(result); + handled = true; + break; + } + } + if !handled { + return Err(err); + } + } + } + } EncodedInstruction::PrintTop => { let top = self.state.top()?; println!("{:#?}", top); @@ -708,6 +863,39 @@ impl<'a> Interpreter<'a> { self.runnable.function_info(function).arity() } + fn match_catch_clause( + &self, + err: &error::Error, + clause: &crate::declaration::CatchClause, + ) -> bool { + let code = err.code_qname(); + clause.errors.iter().any(|pattern| match pattern { + crate::declaration::CatchError::Any => true, + crate::declaration::CatchError::Name(name) => name == &code, + crate::declaration::CatchError::Local(local) => code.local_name() == local, + crate::declaration::CatchError::Namespace(namespace) => code.namespace() == namespace, + }) + } + + fn current_context_values(&self) -> (stack::Value, stack::Value, stack::Value) { + let function_id = self.state.frame().function(); + if self + .runnable + .program() + .declarations + .user_functions + .contains(&function_id) + { + ( + stack::Value::Absent, + stack::Value::Absent, + stack::Value::Absent, + ) + } else { + self.state.context_values() + } + } + fn call(&mut self, arity: u8) -> error::Result<()> { let function = self.state.callable(arity as usize)?; self.call_function(&function, arity) @@ -735,6 +923,54 @@ impl<'a> Interpreter<'a> { self.state.pop() } + pub(crate) fn call_function_with_values( + &mut self, + function: &function::Function, + arguments: &[stack::Value], + ) -> error::Result { + let item: sequence::Item = function.clone().into(); + self.state.push(item); + let arity = arguments.len() as u8; + for arg in arguments { + self.state.push_value(arg.clone()); + } + self.call_function(function, arity)?; + if matches!(function, function::Function::Inline(_)) { + self.run_actual(self.state.frame().base())?; + } + self.state.pop() + } + + fn 
call_template_with_rule( + &mut self, + function: &function::Function, + arguments: &[stack::Value], + mode: pattern::ModeId, + rule: pattern::RuleEntry, + rule_index: usize, + ) -> error::Result { + let item: sequence::Item = function.clone().into(); + self.state.push(item); + let arity = arguments.len() as u8; + for arg in arguments { + self.state.push_value(arg.clone()); + } + match function { + function::Function::Static(data) => { + self.call_static(data.id, arity, &data.closure_vars)? + } + function::Function::Inline(data) => { + self.call_inline_with_rule(data.id, arity, mode, rule, rule_index)? + } + function::Function::Array(array) => self.call_array(array, arity as usize)?, + function::Function::Map(map) => self.call_map(map, arity as usize)?, + } + if matches!(function, function::Function::Inline(_)) { + self.run_actual(self.state.frame().base())?; + } + self.state.pop() + } + fn call_function(&mut self, function: &function::Function, arity: u8) -> error::Result<()> { match function { function::Function::Static(data) => { @@ -782,17 +1018,40 @@ impl<'a> Interpreter<'a> { return Err(error::Error::XPTY0004); } - let arguments = self.coerce_arguments(parameter_types, arity)?; + let arguments = self.coerce_inline_arguments(parameter_types, arity)?; // now we have a list of arguments that we want to push back onto the stack // (they are already reversed) for arg in arguments { - self.state.push(arg); + self.state.push_value(arg); } self.state.push_frame(function_id, arity as usize) } + fn call_inline_with_rule( + &mut self, + function_id: function::InlineFunctionId, + arity: u8, + mode: pattern::ModeId, + rule: pattern::RuleEntry, + rule_index: usize, + ) -> error::Result<()> { + let function = self.runnable.program().inline_function(function_id); + let parameter_types = &function.signature.parameter_types(); + if arity as usize != parameter_types.len() { + return Err(error::Error::XPTY0004); + } + + let arguments = 
self.coerce_inline_arguments(parameter_types, arity)?; + for arg in arguments { + self.state.push_value(arg); + } + + self.state + .push_frame_with_rule(function_id, arity as usize, mode, rule, rule_index) + } + fn coerce_arguments( &mut self, parameter_types: &[Option], @@ -810,20 +1069,64 @@ impl<'a> Interpreter<'a> { let mut arguments = Vec::with_capacity(arity as usize); let static_context = self.runnable.static_context(); let xot = self.state.xot(); - for (parameter_type, stack_value) in parameter_types.iter().zip(stack_values) { - let sequence: sequence::Sequence = stack_value.try_into()?; - if let Some(type_) = parameter_type { - // matching also takes care of function conversion rules - let sequence = sequence.sequence_type_matching_function_conversion( - type_, - static_context, - xot, - &|function| self.runnable.function_info(function).signature(), - )?; - arguments.push(sequence); - } else { - // no need to do any checking or conversion - arguments.push(sequence); + { + let type_table = self.state.type_table.borrow(); + for (parameter_type, stack_value) in parameter_types.iter().zip(stack_values) { + let sequence: sequence::Sequence = stack_value.try_into()?; + if let Some(type_) = parameter_type { + // matching also takes care of function conversion rules + let sequence = sequence.sequence_type_matching_function_conversion( + type_, + static_context, + xot, + &type_table, + &|function| self.runnable.function_info(function).signature(), + )?; + arguments.push(sequence); + } else { + // no need to do any checking or conversion + arguments.push(sequence); + } + } + } + self.state.truncate_arguments(arity as usize); + Ok(arguments) + } + + fn coerce_inline_arguments( + &mut self, + parameter_types: &[Option], + arity: u8, + ) -> error::Result> { + let stack_values = self.state.arguments(arity as usize); + let mut arguments = Vec::with_capacity(arity as usize); + let static_context = self.runnable.static_context(); + let xot = self.state.xot(); + { + let 
type_table = self.state.type_table.borrow(); + for (parameter_type, stack_value) in parameter_types.iter().zip(stack_values) { + match stack_value { + stack::Value::Absent => { + if parameter_type.is_some() { + return Err(error::Error::XPDY0002); + } + arguments.push(stack::Value::Absent); + } + stack::Value::Sequence(sequence) => { + let sequence = if let Some(type_) = parameter_type { + sequence.clone().sequence_type_matching_function_conversion( + type_, + static_context, + xot, + &type_table, + &|function| self.runnable.function_info(function).signature(), + )? + } else { + sequence.clone() + }; + arguments.push(stack::Value::Sequence(sequence)); + } + } } } self.state.truncate_arguments(arity as usize); @@ -1058,6 +1361,22 @@ impl<'a> Interpreter<'a> { value.atomized_one(self.state.xot()) } + fn pop_with_param_map( + &mut self, + param_count: u16, + ) -> error::Result> { + let mut params = AHashMap::new(); + for _ in 0..param_count { + let name_sequence = self.state.pop()?; + let name_item = name_sequence.one()?; + let name_atomic = name_item.to_atomic()?; + let name: xot::xmlname::OwnedName = name_atomic.try_into()?; + let value = self.state.pop()?; + params.insert(name, value); + } + Ok(params) + } + fn pop_atomic_option(&mut self) -> error::Result> { let value = self.state.pop()?; value.atomized_option(self.state.xot()) @@ -1144,7 +1463,7 @@ impl<'a> Interpreter<'a> { // if we have a parent we're already in another document, // in which case we want to make a clone first let node = if self.state.xot.parent(node).is_some() { - self.state.xot.clone_node(node) + self.state.clone_node_with_type(node) } else { node }; @@ -1174,12 +1493,20 @@ impl<'a> Interpreter<'a> { let value = xot.value(node); match value { // root and element are shallow copies - xot::Value::Document => xot.new_document(), + xot::Value::Document => { + let cloned = xot.new_document(); + self.state.type_table.borrow_mut().copy_type(node, cloned); + cloned + } // TODO: work on copying prefixes 
- xot::Value::Element(element) => xot.new_element(element.name()), + xot::Value::Element(element) => { + let cloned = xot.new_element(element.name()); + self.state.type_table.borrow_mut().copy_type(node, cloned); + cloned + } // we can clone (deep-copy) these nodes as it's the same // operation as shallow copy - _ => xot.clone_node(node), + _ => self.state.clone_node_with_type(node), } } @@ -1187,12 +1514,106 @@ impl<'a> Interpreter<'a> { &mut self, mode: pattern::ModeId, sequence: sequence::Sequence, + params: Option<&AHashMap>, + ) -> error::Result { + let mut r: Vec = Vec::new(); + let size: IBig = sequence.len().into(); + + for (i, item) in sequence.iter().enumerate() { + let sequence = self.apply_templates_item(mode, item, i, size.clone(), params)?; + if let Some(sequence) = sequence { + for item in sequence.iter() { + r.push(item.clone()); + } + } + } + Ok(r.into()) + } + + fn apply_imports( + &mut self, + params: Option<&AHashMap>, + ) -> error::Result { + let frame = self.state.frame(); + let mode = frame + .mode + .ok_or_else(|| error::Error::Unsupported("xsl:apply-imports has no current mode".to_string()))?; + let rule = frame + .rule + .ok_or_else(|| error::Error::Unsupported("xsl:apply-imports has no current rule".to_string()))?; + let (item, position, size) = self.current_context_values(); + let sequence: sequence::Sequence = item.clone().try_into()?; + let item = match sequence { + sequence::Sequence::One(one) => one.item().clone(), + _ => return Ok(sequence::Sequence::default()), + }; + self.apply_imports_item(mode, item, position, size, params, rule) + .map(|sequence| sequence.unwrap_or_default()) + } + + fn apply_next_match( + &mut self, + params: Option<&AHashMap>, + ) -> error::Result { + let frame = self.state.frame(); + let mode = frame + .mode + .ok_or_else(|| error::Error::Unsupported("xsl:next-match has no current mode".to_string()))?; + let current_rule_index = frame + .rule_index + .ok_or_else(|| error::Error::Unsupported("xsl:next-match 
has no current rule".to_string()))?; + let (item, position, size) = self.current_context_values(); + let sequence: sequence::Sequence = item.clone().try_into()?; + let item = match sequence { + sequence::Sequence::One(one) => one.item().clone(), + _ => return Ok(sequence::Sequence::default()), + }; + let current_rule = frame + .rule + .ok_or_else(|| error::Error::Unsupported("xsl:next-match has no current rule".to_string()))?; + let rule_entry = self.select_next_rule(mode, &item, current_rule_index, current_rule); + + if let Some((rule_index, rule_entry)) = rule_entry { + if rule_entry.is_builtin { + return self.apply_builtin_template(mode, item, params); + } + let mut base_args: Vec = + vec![stack::Value::from(item.clone()), position.clone(), size.clone()]; + base_args.extend( + self.state + .global_params() + .iter() + .cloned() + .map(stack::Value::from), + ); + let template_args = + self.resolve_template_params(rule_entry.function_id, &base_args, params)?; + let mut arguments = base_args; + arguments.extend(template_args.into_iter().map(stack::Value::from)); + let function = + function::InlineFunctionData::new(rule_entry.function_id, Vec::new()).into(); + self.call_template_with_rule(&function, &arguments, mode, rule_entry, rule_index) + } else { + Ok(sequence::Sequence::default()) + } + } + + fn apply_imports_sequence( + &mut self, + mode: pattern::ModeId, + sequence: sequence::Sequence, + params: Option<&AHashMap>, + current_rule: pattern::RuleEntry, ) -> error::Result { let mut r: Vec = Vec::new(); let size: IBig = sequence.len().into(); for (i, item) in sequence.iter().enumerate() { - let sequence = self.apply_templates_item(mode, item, i, size.clone())?; + let position: IBig = (i + 1).into(); + let position_value = stack::Value::from(atomic::Atomic::from(position)); + let size_value = stack::Value::from(atomic::Atomic::from(size.clone())); + let sequence = + self.apply_imports_item(mode, item, position_value, size_value, params, current_rule)?; if let 
Some(sequence) = sequence { for item in sequence.iter() { r.push(item.clone()); @@ -1202,41 +1623,401 @@ impl<'a> Interpreter<'a> { Ok(r.into()) } + fn apply_imports_item( + &mut self, + mode: pattern::ModeId, + item: sequence::Item, + position: stack::Value, + size: stack::Value, + params: Option<&AHashMap>, + current_rule: pattern::RuleEntry, + ) -> error::Result> { + let rule_entry = + self.select_apply_imports_rule(mode, &item, current_rule); + + if let Some((rule_index, rule_entry)) = rule_entry { + if rule_entry.is_builtin { + return self.apply_builtin_template(mode, item, params).map(Some); + } + let mut base_args: Vec = vec![ + stack::Value::from(item), + position, + size, + ]; + base_args.extend( + self.state + .global_params() + .iter() + .cloned() + .map(stack::Value::from), + ); + let template_args = + self.resolve_template_params(rule_entry.function_id, &base_args, params)?; + let mut arguments = base_args; + arguments.extend(template_args.into_iter().map(stack::Value::from)); + let function = + function::InlineFunctionData::new(rule_entry.function_id, Vec::new()).into(); + self.call_template_with_rule(&function, &arguments, mode, rule_entry, rule_index) + .map(Some) + } else { + Ok(None) + } + } + fn apply_templates_item( &mut self, mode: pattern::ModeId, item: sequence::Item, position: usize, size: IBig, + params: Option<&AHashMap>, ) -> error::Result> { - let function_id = self.lookup_pattern(mode, &item); + let rule_entry = self.select_first_rule(mode, &item); - if let Some(function_id) = function_id { + if let Some((rule_index, rule_entry)) = rule_entry { + if rule_entry.is_builtin { + return self.apply_builtin_template(mode, item, params).map(Some); + } let position: IBig = (position + 1).into(); - let arguments: Vec = vec![ - item.into(), - atomic::Atomic::from(position).into(), - atomic::Atomic::from(size.clone()).into(), + let mut base_args: Vec = vec![ + stack::Value::from(item), + stack::Value::from(atomic::Atomic::from(position)), + 
stack::Value::from(atomic::Atomic::from(size.clone())), ]; - let function = function::InlineFunctionData::new(function_id, Vec::new()).into(); - self.call_function_with_arguments(&function, &arguments) + base_args.extend( + self.state + .global_params() + .iter() + .cloned() + .map(stack::Value::from), + ); + let template_args = + self.resolve_template_params(rule_entry.function_id, &base_args, params)?; + let mut arguments = base_args; + arguments.extend(template_args.into_iter().map(stack::Value::from)); + let function = + function::InlineFunctionData::new(rule_entry.function_id, Vec::new()).into(); + self.call_template_with_rule(&function, &arguments, mode, rule_entry, rule_index) .map(Some) } else { Ok(None) } } + fn apply_builtin_template( + &mut self, + mode: pattern::ModeId, + item: sequence::Item, + params: Option<&AHashMap>, + ) -> error::Result { + let on_no_match = self + .runnable + .program() + .declarations + .mode_configs + .get(&mode) + .and_then(|config| config.on_no_match) + .unwrap_or(OnNoMatch::TextOnlyCopy); + + match on_no_match { + OnNoMatch::TextOnlyCopy => { + return match item { + sequence::Item::Node(node) => match self.state.xot().value(node) { + xot::Value::Document | xot::Value::Element(_) => { + let children = self + .state + .xot() + .children(node) + .map(sequence::Item::Node) + .collect::>(); + self.apply_templates_sequence(mode, children.into(), params) + } + xot::Value::Attribute(_) | xot::Value::Text(_) => { + let text = item.string_value(self.state.xot())?; + let text_node = self.state.xot.new_text(&text); + Ok(sequence::Sequence::from(sequence::Item::Node(text_node))) + } + xot::Value::ProcessingInstruction(_) | xot::Value::Comment(_) => { + Ok(sequence::Sequence::default()) + } + _ => Ok(sequence::Sequence::default()), + }, + sequence::Item::Atomic(_) => Ok(sequence::Sequence::default()), + sequence::Item::Function(_) => Err(error::Error::XTDE0450), + }; + } + OnNoMatch::ShallowCopy => { + return match item { + 
sequence::Item::Node(node) => match self.state.xot().value(node) { + xot::Value::Document | xot::Value::Element(_) => { + let copied = self.shallow_copy_node(node); + let children = self + .state + .xot() + .children(node) + .map(sequence::Item::Node) + .collect::>(); + let child_value = + self.apply_templates_sequence(mode, children.into(), params)?; + self.xml_append(copied, child_value)?; + Ok(sequence::Sequence::from(sequence::Item::Node(copied))) + } + _ => { + let copied = self.state.clone_node_with_type(node); + Ok(sequence::Sequence::from(sequence::Item::Node(copied))) + } + }, + sequence::Item::Atomic(_) => Ok(sequence::Sequence::default()), + sequence::Item::Function(_) => Err(error::Error::XTDE0450), + }; + } + OnNoMatch::DeepCopy => { + return match item { + sequence::Item::Node(node) => { + let copied = self.state.clone_node_with_type(node); + Ok(sequence::Sequence::from(sequence::Item::Node(copied))) + } + sequence::Item::Atomic(_) => Ok(sequence::Sequence::default()), + sequence::Item::Function(_) => Err(error::Error::XTDE0450), + }; + } + OnNoMatch::ShallowSkip => { + return match item { + sequence::Item::Node(node) => match self.state.xot().value(node) { + xot::Value::Document | xot::Value::Element(_) => { + let children = self + .state + .xot() + .children(node) + .map(sequence::Item::Node) + .collect::>(); + self.apply_templates_sequence(mode, children.into(), params) + } + _ => Ok(sequence::Sequence::default()), + }, + sequence::Item::Atomic(_) => Ok(sequence::Sequence::default()), + sequence::Item::Function(_) => Err(error::Error::XTDE0450), + }; + } + OnNoMatch::DeepSkip => { + return match item { + sequence::Item::Node(_) => Ok(sequence::Sequence::default()), + sequence::Item::Atomic(_) => Ok(sequence::Sequence::default()), + sequence::Item::Function(_) => Err(error::Error::XTDE0450), + }; + } + OnNoMatch::Fail => { + return Err(error::Error::Unsupported( + "xsl:mode on-no-match=\"fail\" is not supported".to_string(), + )); + } + } + + } + + 
pub(crate) fn call_named_template( + &mut self, + function_id: function::InlineFunctionId, + params: Option<&AHashMap>, + ) -> error::Result { + let (item, position, size) = self.current_context_values(); + let mut base_args = vec![item, position, size]; + base_args.extend( + self.state + .global_params() + .iter() + .cloned() + .map(stack::Value::from), + ); + let template_args = self.resolve_template_params(function_id, &base_args, params)?; + let mut arguments = base_args; + arguments.extend(template_args.into_iter().map(stack::Value::from)); + let function = function::InlineFunctionData::new(function_id, Vec::new()).into(); + self.call_function_with_values(&function, &arguments) + } + + fn resolve_template_params( + &mut self, + function_id: function::InlineFunctionId, + base_args: &[stack::Value], + params: Option<&AHashMap>, + ) -> error::Result> { + let template_params = match self + .runnable + .program() + .declarations + .template_params + .get(&function_id) + { + Some(params) => params.clone(), + None => return Ok(Vec::new()), + }; + let mut values: AHashMap = AHashMap::new(); + if let Some(params) = params { + for (name, value) in params { + values.insert(name.clone(), value.clone()); + } + } + let template_param_count = template_params.len(); + let mut resolved: Vec = Vec::with_capacity(template_param_count); + for template_param in &template_params { + let value = if let Some(value) = values.get(&template_param.name) { + value.clone() + } else if let Some(default_fn) = template_param.default { + let mut args = Vec::with_capacity(base_args.len() + template_param_count); + args.extend(base_args.iter().cloned()); + for existing in &resolved { + args.push(stack::Value::from(existing.clone())); + } + for _ in resolved.len()..template_param_count { + args.push(stack::Value::from(sequence::Sequence::default())); + } + let function = function::InlineFunctionData::new(default_fn, Vec::new()).into(); + self.call_function_with_values(&function, &args)? 
+ } else if template_param.required { + return Err(error::Error::XTDE0060); + } else { + sequence::Sequence::default() + }; + values.insert(template_param.name.clone(), value.clone()); + resolved.push(value); + } + Ok(resolved) + } + pub(crate) fn lookup_pattern( &mut self, mode: pattern::ModeId, item: &sequence::Item, - ) -> Option { - self.runnable + ) -> Option { + self.select_first_rule(mode, item).map(|(_, rule)| rule) + } + + pub(crate) fn lookup_next_match( + &mut self, + mode: pattern::ModeId, + item: &sequence::Item, + current_rule: pattern::RuleEntry, + ) -> Option { + let (current_index, _) = self + .select_first_rule(mode, item) + .and_then(|(idx, rule)| if rule == current_rule { Some((idx, rule)) } else { None })?; + self.select_next_rule(mode, item, current_index, current_rule) + .map(|(_, rule)| rule) + } + + fn select_first_rule( + &mut self, + mode: pattern::ModeId, + item: &sequence::Item, + ) -> Option<(usize, pattern::RuleEntry)> { + let rules = self + .runnable + .program() + .declarations + .mode_rules + .get(&mode)?; + for (idx, (pattern, rule)) in rules.iter().enumerate() { + if rule.is_builtin { + continue; + } + if self.matches(pattern, item) { + return Some((idx, *rule)); + } + } + for (idx, (pattern, rule)) in rules.iter().enumerate() { + if !rule.is_builtin { + continue; + } + if self.matches(pattern, item) { + return Some((idx, *rule)); + } + } + None + } + + fn select_next_rule( + &mut self, + mode: pattern::ModeId, + item: &sequence::Item, + current_index: usize, + current_rule: pattern::RuleEntry, + ) -> Option<(usize, pattern::RuleEntry)> { + let rules = self + .runnable .program() .declarations - .mode_lookup - .lookup(mode, |pattern| self.matches(pattern, item)) - .copied() + .mode_rules + .get(&mode)?; + if !current_rule.is_builtin { + for (idx, (pattern, rule)) in rules.iter().enumerate().skip(current_index + 1) { + if rule.is_builtin { + continue; + } + if *rule == current_rule { + continue; + } + if self.matches(pattern, 
item) { + return Some((idx, *rule)); + } + } + for (idx, (pattern, rule)) in rules.iter().enumerate() { + if !rule.is_builtin { + continue; + } + if self.matches(pattern, item) { + return Some((idx, *rule)); + } + } + return None; + } + + for (idx, (pattern, rule)) in rules.iter().enumerate().skip(current_index + 1) { + if !rule.is_builtin { + continue; + } + if *rule == current_rule { + continue; + } + if self.matches(pattern, item) { + return Some((idx, *rule)); + } + } + None + } + + fn select_apply_imports_rule( + &mut self, + mode: pattern::ModeId, + item: &sequence::Item, + current_rule: pattern::RuleEntry, + ) -> Option<(usize, pattern::RuleEntry)> { + let rules = self + .runnable + .program() + .declarations + .mode_rules + .get(&mode)?; + for (idx, (pattern, rule)) in rules.iter().enumerate() { + if rule.is_builtin { + continue; + } + if rule.import_level <= current_rule.import_level { + continue; + } + if self.matches(pattern, item) { + return Some((idx, *rule)); + } + } + for (idx, (pattern, rule)) in rules.iter().enumerate() { + if !rule.is_builtin { + continue; + } + if self.matches(pattern, item) { + return Some((idx, *rule)); + } + } + None } // The interpreter can return an error for any byte code, in any level of diff --git a/xee-interpreter/src/interpreter/program.rs b/xee-interpreter/src/interpreter/program.rs index 685a91c46..2aa090b1e 100644 --- a/xee-interpreter/src/interpreter/program.rs +++ b/xee-interpreter/src/interpreter/program.rs @@ -1,9 +1,11 @@ use crate::context; use crate::declaration::Declarations; use crate::function; +use crate::span::SourceSpan; use xee_name::Name; use xee_xpath_ast::ast::Span; +use super::instruction::{encode_instruction, Instruction}; use super::Runnable; #[derive(Debug)] @@ -79,6 +81,26 @@ impl Program { function::InlineFunctionId(id) } + pub fn reserve_function_slots(&mut self, count: usize) -> usize { + let start = self.functions.len(); + self.functions + .resize_with(start + count, 
Self::placeholder_function); + start + } + + pub fn set_function( + &mut self, + function_id: function::InlineFunctionId, + function: function::InlineFunction, + ) { + let index = function_id.0; + if index >= self.functions.len() { + self.functions + .resize_with(index + 1, Self::placeholder_function); + } + self.functions[index] = function; + } + pub(crate) fn get_function(&self, index: usize) -> &function::InlineFunction { &self.functions[index] } @@ -93,6 +115,22 @@ impl Program { pub(crate) fn main_id(&self) -> function::InlineFunctionId { function::InlineFunctionId(self.functions.len() - 1) } + + fn placeholder_function() -> function::InlineFunction { + let mut chunk = Vec::new(); + encode_instruction(Instruction::Return, &mut chunk); + function::InlineFunction { + name: "".to_string(), + signature: function::Signature::new(Vec::new(), None), + constants: Vec::new(), + steps: Vec::new(), + cast_types: Vec::new(), + sequence_types: Vec::new(), + closure_names: Vec::new(), + chunk, + spans: vec![SourceSpan::empty()], + } + } } /// Given a function provide information about it. 
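Note on the `program.rs` hunk above: `reserve_function_slots` and `set_function` together form a reserve-then-fill pattern, where slots are allocated with placeholders first and overwritten once the real function is available. The following is a minimal standalone sketch of that pattern, not the actual Xee code. `Func` and `Table` are hypothetical stand-ins for `function::InlineFunction` and `Program`, and the placeholder here is just an empty name rather than a chunk containing a single `Return` instruction.

```rust
// Hypothetical stand-in for `function::InlineFunction`; only a name is kept.
#[derive(Debug, Clone, PartialEq)]
struct Func {
    name: String,
}

// Hypothetical stand-in for `Program`; only the function table is modelled.
#[derive(Default)]
struct Table {
    functions: Vec<Func>,
}

impl Table {
    // Placeholder used for slots that have been reserved but not yet filled.
    fn placeholder() -> Func {
        Func { name: String::new() }
    }

    // Append `count` placeholders and return the index of the first new slot.
    fn reserve_slots(&mut self, count: usize) -> usize {
        let start = self.functions.len();
        self.functions.resize_with(start + count, Self::placeholder);
        start
    }

    // Overwrite slot `index`, growing the table with placeholders if the
    // index lies beyond the current end.
    fn set(&mut self, index: usize, function: Func) {
        if index >= self.functions.len() {
            self.functions.resize_with(index + 1, Self::placeholder);
        }
        self.functions[index] = function;
    }
}
```

Reserving first lets a caller hand out stable indices before the corresponding functions exist, which is presumably useful when templates refer to each other. That motivation is an inference from the hunk, not something the patch states.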
diff --git a/xee-interpreter/src/interpreter/runnable.rs b/xee-interpreter/src/interpreter/runnable.rs index e5c8bfc42..89c4531d5 100644 --- a/xee-interpreter/src/interpreter/runnable.rs +++ b/xee-interpreter/src/interpreter/runnable.rs @@ -1,5 +1,6 @@ use std::rc::Rc; +use ahash::AHashMap; use ibig::ibig; use iri_string::types::IriReferenceStr; use xot::Xot; @@ -8,7 +9,7 @@ use crate::context::DocumentsRef; use crate::context::DynamicContext; use crate::context::StaticContext; use crate::error::SpannedError; -use crate::function::Function; +use crate::function::{Function, InlineFunctionData}; use crate::interpreter::interpret::ContextInfo; use crate::sequence; use crate::stack; @@ -37,8 +38,8 @@ impl<'a> Runnable<'a> { } fn run_value(&self, xot: &'a mut Xot) -> error::SpannedResult { - let arguments = self.dynamic_context.arguments().unwrap(); - let mut interpreter = Interpreter::new(self, xot); + let arguments = self.resolve_global_param_arguments(xot)?; + let mut interpreter = Interpreter::new(self, xot, self.dynamic_context.type_table()); let context_info = if let Some(context_item) = self.dynamic_context.context_item() { ContextInfo { @@ -78,11 +79,134 @@ impl<'a> Runnable<'a> { } } + fn resolve_global_param_arguments( + &self, + xot: &'a mut Xot, + ) -> error::SpannedResult> { + if self.program.declarations.global_params.is_empty() { + return Ok(self.dynamic_context.arguments()?); + } + let globals = &self.program.declarations.global_params; + let mut explicit: AHashMap = + AHashMap::new(); + for (name, value) in self.dynamic_context.variables() { + explicit.insert(name.clone(), value.clone()); + } + let mut values: AHashMap = AHashMap::new(); + for global in globals { + if global.overrideable { + if let Some(value) = explicit.get(&global.name) { + values.insert(global.name.clone(), value.clone()); + } + } + } + let iterations = globals.len().max(1); + + for _ in 0..iterations { + for global in globals { + let value = if global.overrideable { + if let 
Some(value) = explicit.get(&global.name) { + value.clone() + } else if let Some(default_fn) = global.default { + let args = globals + .iter() + .map(|param| values.get(&param.name).cloned().unwrap_or_default()) + .collect::>(); + let function = InlineFunctionData::new(default_fn, Vec::new()).into(); + let mut interpreter = Interpreter::new(self, xot, self.dynamic_context.type_table()); + interpreter + .call_function_with_arguments(&function, &args) + .map_err(|error| error::SpannedError { + error, + span: Some(self.program.span().into()), + })? + } else if global.required { + return Err(error::SpannedError { + error: error::Error::XTDE0050, + span: Some(self.program.span().into()), + }); + } else { + sequence::Sequence::default() + } + } else if let Some(default_fn) = global.default { + let args = globals + .iter() + .map(|param| values.get(&param.name).cloned().unwrap_or_default()) + .collect::>(); + let function = InlineFunctionData::new(default_fn, Vec::new()).into(); + let mut interpreter = Interpreter::new(self, xot, self.dynamic_context.type_table()); + interpreter + .call_function_with_arguments(&function, &args) + .map_err(|error| error::SpannedError { + error, + span: Some(self.program.span().into()), + })? + } else if global.required { + return Err(error::SpannedError { + error: error::Error::XTDE0050, + span: Some(self.program.span().into()), + }); + } else { + sequence::Sequence::default() + }; + + values.insert(global.name.clone(), value); + } + } + + let mut resolved = Vec::with_capacity(globals.len()); + for global in globals { + resolved.push(values.get(&global.name).cloned().unwrap_or_default()); + } + Ok(resolved) + } + /// Run the program against a sequence item. pub fn many(&self, xot: &'a mut Xot) -> error::SpannedResult { Ok(self.run_value(xot)?.try_into()?) 
} + pub fn call_named_template( + &self, + xot: &'a mut Xot, + name: &xot::xmlname::OwnedName, + params: Option<&AHashMap>, + ) -> error::SpannedResult { + let function_id = self + .program + .declarations + .named_templates + .get(name) + .copied() + .ok_or(SpannedError { + error: error::Error::Unsupported(String::from("Named template not found")), + span: Some(self.program.span().into()), + })?; + + let arguments = self.resolve_global_param_arguments(xot)?; + let mut interpreter = Interpreter::new(self, xot, self.dynamic_context.type_table()); + let context_info = if let Some(context_item) = self.dynamic_context.context_item() { + ContextInfo { + item: context_item.clone().into(), + position: ibig!(1).into(), + size: ibig!(1).into(), + } + } else { + ContextInfo { + item: stack::Value::Absent, + position: stack::Value::Absent, + size: stack::Value::Absent, + } + }; + interpreter.start(context_info, arguments); + interpreter + .call_named_template(function_id, params) + .map_err(|error| SpannedError { + error, + span: Some(self.program.span().into()), + }) + } + /// Run the program, expect a single item as the result. 
pub fn one(&self, xot: &'a mut Xot) -> error::SpannedResult { let sequence = self.many(xot)?; diff --git a/xee-interpreter/src/interpreter/state.rs b/xee-interpreter/src/interpreter/state.rs index 22f250532..2079b9b83 100644 --- a/xee-interpreter/src/interpreter/state.rs +++ b/xee-interpreter/src/interpreter/state.rs @@ -8,8 +8,10 @@ use xot::Xot; use crate::error; use crate::function; +use crate::pattern::{ModeId, RuleEntry}; use crate::sequence; use crate::stack; +use crate::context::TypeTableRef; const FRAMES_MAX: usize = 64; @@ -18,6 +20,9 @@ pub(crate) struct Frame { function: function::InlineFunctionId, base: usize, pub(crate) ip: usize, + pub(crate) mode: Option, + pub(crate) rule: Option, + pub(crate) rule_index: Option, } impl Frame { @@ -41,15 +46,26 @@ pub struct State<'a> { build_stack: Vec, frames: ArrayVec, regex_cache: RefCell>>, + global_params: Vec, pub(crate) xot: &'a mut Xot, + pub(crate) type_table: TypeTableRef, } -#[derive(Debug)] +#[derive(Debug, Clone)] +pub(crate) struct StateSnapshot { + stack: Vec, + build_stack: Vec, + frames: ArrayVec, + xot: Option, + type_table: Option, +} + +#[derive(Debug, Clone)] struct ItemBuildStackEntry { build_stack: Vec, } -#[derive(Debug)] +#[derive(Debug, Clone)] struct BuildStackEntry { item: ItemBuildStackEntry, } @@ -88,10 +104,87 @@ impl<'a> State<'a> { build_stack: vec![], frames: ArrayVec::new(), regex_cache: RefCell::new(HashMap::new()), + global_params: Vec::new(), xot, + type_table: TypeTableRef::new(), } } + pub(crate) fn new_with_type_table(xot: &'a mut Xot, type_table: TypeTableRef) -> Self { + Self { + stack: vec![], + build_stack: vec![], + frames: ArrayVec::new(), + regex_cache: RefCell::new(HashMap::new()), + global_params: Vec::new(), + xot, + type_table, + } + } + + pub(crate) fn snapshot(&self, include_xot: bool) -> StateSnapshot { + StateSnapshot { + stack: self.stack.clone(), + build_stack: self.build_stack.clone(), + frames: self.frames.clone(), + xot: include_xot.then(|| 
self.xot.clone()), + type_table: include_xot.then(|| self.type_table.borrow().clone()), + } + } + + pub(crate) fn restore(&mut self, snapshot: StateSnapshot) { + self.stack = snapshot.stack; + self.build_stack = snapshot.build_stack; + self.frames = snapshot.frames; + if let Some(xot) = snapshot.xot { + *self.xot = xot; + } + if let Some(type_table) = snapshot.type_table { + *self.type_table.borrow_mut() = type_table; + } + } + + pub(crate) fn set_global_params(&mut self, params: Vec) { + self.global_params = params; + } + + pub(crate) fn set_node_type(&mut self, node: xot::Node, xs: xee_schema_type::Xs) { + self.type_table.borrow_mut().set(node, xs); + } + + pub(crate) fn node_type(&self, node: xot::Node) -> Option { + self.type_table.borrow().get(node) + } + + pub(crate) fn clone_node_with_type(&mut self, node: xot::Node) -> xot::Node { + let cloned = self.xot.clone_node(node); + self.type_table.borrow_mut().copy_type(node, cloned); + cloned + } + + pub(crate) fn global_params(&self) -> &[sequence::Sequence] { + &self.global_params + } + + pub(crate) fn context_args( + &self, + ) -> error::Result<(sequence::Sequence, sequence::Sequence, sequence::Sequence)> { + let base = self.frame().base(); + let item = sequence::Sequence::try_from(&self.stack[base])?; + let position = sequence::Sequence::try_from(&self.stack[base + 1])?; + let size = sequence::Sequence::try_from(&self.stack[base + 2])?; + Ok((item, position, size)) + } + + pub(crate) fn context_values(&self) -> (stack::Value, stack::Value, stack::Value) { + let base = self.frame().base(); + ( + self.stack[base].clone(), + self.stack[base + 1].clone(), + self.stack[base + 2].clone(), + ) + } + pub(crate) fn push(&mut self, sequence: T) where T: Into, @@ -169,6 +262,9 @@ impl<'a> State<'a> { function: function_id, ip: 0, base: 0, + mode: None, + rule: None, + rule_index: None, }); } @@ -184,6 +280,31 @@ impl<'a> State<'a> { function: function_id, ip: 0, base: self.stack.len() - arity, + mode: None, + rule: None, 
+ rule_index: None, + }); + Ok(()) + } + + pub(crate) fn push_frame_with_rule( + &mut self, + function_id: function::InlineFunctionId, + arity: usize, + mode: ModeId, + rule: RuleEntry, + rule_index: usize, + ) -> error::Result<()> { + if self.frames.len() >= self.frames.capacity() { + return Err(error::Error::StackOverflow); + } + self.frames.push(Frame { + function: function_id, + ip: 0, + base: self.stack.len() - arity, + mode: Some(mode), + rule: Some(rule), + rule_index: Some(rule_index), }); Ok(()) } @@ -275,3 +396,42 @@ impl<'a> State<'a> { self.xot } } + +#[cfg(test)] +mod tests { + use super::State; + use xee_schema_type::Xs; + use xot::Xot; + + #[test] + fn snapshot_restores_type_table_with_xot() { + let mut xot = Xot::new(); + let doc = xot.parse(r#""#).unwrap(); + let doc_el = xot.document_element(doc).unwrap(); + let a = xot.first_child(doc_el).unwrap(); + let mut state = State::new(&mut xot); + state.set_node_type(a, Xs::String); + let snapshot = state.snapshot(true); + + state.set_node_type(a, Xs::UntypedAtomic); + state.restore(snapshot); + + assert_eq!(state.node_type(a), Some(Xs::String)); + } + + #[test] + fn snapshot_does_not_restore_type_table_without_xot() { + let mut xot = Xot::new(); + let doc = xot.parse(r#""#).unwrap(); + let doc_el = xot.document_element(doc).unwrap(); + let a = xot.first_child(doc_el).unwrap(); + let mut state = State::new(&mut xot); + state.set_node_type(a, Xs::String); + let snapshot = state.snapshot(false); + + state.set_node_type(a, Xs::UntypedAtomic); + state.restore(snapshot); + + assert_eq!(state.node_type(a), Some(Xs::UntypedAtomic)); + } +} diff --git a/xee-interpreter/src/library/json.rs b/xee-interpreter/src/library/json.rs index c5eca13ff..e60e410e2 100644 --- a/xee-interpreter/src/library/json.rs +++ b/xee-interpreter/src/library/json.rs @@ -2,7 +2,10 @@ use xee_schema_type::Xs; use xee_xpath_macros::xpath_fn; use xot::Xot; -use crate::{atomic, context, error, function, interpreter::Interpreter, sequence, 
wrap_xpath_fn}; +use crate::{ + atomic, context, error, function, interpreter::Interpreter, sequence, wrap_xpath_fn, + xml::TypeTable, +}; use super::StaticFunctionDescription; @@ -27,7 +30,12 @@ fn parse_json2( options: function::Map, ) -> error::Result> { let parameters = - ParseJsonParameters::from_map(&options, context.static_context(), interpreter.xot())?; + ParseJsonParameters::from_map( + &options, + context.static_context(), + interpreter.xot(), + &interpreter.state.type_table.borrow(), + )?; if let Some(json_text) = json_text { let value = json::parse(json_text).map_err(|_| error::Error::FOJS0001)?; @@ -62,8 +70,9 @@ impl ParseJsonParameters { map: &function::Map, static_context: &context::StaticContext, xot: &Xot, + type_table: &TypeTable, ) -> error::Result { - let c = sequence::OptionParameterConverter::new(map, static_context, xot); + let c = sequence::OptionParameterConverter::new(map, static_context, xot, type_table); let liberal = c .option_with_default("liberal", Xs::Boolean, false) diff --git a/xee-interpreter/src/library/map.rs b/xee-interpreter/src/library/map.rs index 5d1126ef7..7e70875a6 100644 --- a/xee-interpreter/src/library/map.rs +++ b/xee-interpreter/src/library/map.rs @@ -98,12 +98,14 @@ impl MergeOptions { // apply function conversion rules as specified by the option parameter // conventions let runnable = interpreter.runnable(); + let type_table = interpreter.state.type_table.borrow(); let duplicates = duplicates .clone() .sequence_type_matching_function_conversion( &sequence_type, runnable.static_context(), interpreter.xot(), + &type_table, &|function| runnable.program().function_info(function).signature(), )?; // take the first value, which should be a string diff --git a/xee-interpreter/src/library/parse.rs b/xee-interpreter/src/library/parse.rs index 14b461a6f..7699ec9e5 100644 --- a/xee-interpreter/src/library/parse.rs +++ b/xee-interpreter/src/library/parse.rs @@ -57,11 +57,16 @@ fn serialize1( arg: &sequence::Sequence, ) -> 
error::Result { let map = function::Map::new(vec![])?; - let serialization_parameters = sequence::SerializationParameters::from_map( - map, - context.static_context(), - interpreter.xot_mut(), - )?; + let type_table = interpreter.state.type_table.clone(); + let serialization_parameters = { + let type_table = type_table.borrow(); + sequence::SerializationParameters::from_map( + map, + context.static_context(), + interpreter.xot(), + &type_table, + )? + }; arg.serialize(serialization_parameters, interpreter.xot_mut()) } @@ -82,11 +87,16 @@ fn serialize2( } else { function::Map::new(vec![])? }; - let serialization_parameters = sequence::SerializationParameters::from_map( - map, - context.static_context(), - interpreter.xot_mut(), - )?; + let type_table = interpreter.state.type_table.clone(); + let serialization_parameters = { + let type_table = type_table.borrow(); + sequence::SerializationParameters::from_map( + map, + context.static_context(), + interpreter.xot(), + &type_table, + )? + }; arg.serialize(serialization_parameters, interpreter.xot_mut()) } diff --git a/xee-interpreter/src/pattern/mod.rs b/xee-interpreter/src/pattern/mod.rs index a68aaf1ad..aef01429a 100644 --- a/xee-interpreter/src/pattern/mod.rs +++ b/xee-interpreter/src/pattern/mod.rs @@ -5,5 +5,5 @@ mod mode; mod pattern_core; mod pattern_lookup; -pub use mode::{ModeId, ModeLookup}; +pub use mode::{ModeId, RuleEntry}; pub(crate) use pattern_core::PredicateMatcher; diff --git a/xee-interpreter/src/pattern/mode.rs b/xee-interpreter/src/pattern/mode.rs index e1c26c1df..4c86eec87 100644 --- a/xee-interpreter/src/pattern/mode.rs +++ b/xee-interpreter/src/pattern/mode.rs @@ -1,10 +1,13 @@ -use ahash::{HashMap, HashMapExt}; - -use xee_xpath_ast::Pattern; - use crate::function; -use super::pattern_lookup::PatternLookup; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RuleEntry { + pub function_id: function::InlineFunctionId, + pub priority: rust_decimal::Decimal, + pub import_level: u32, + pub 
declaration_order: i64, + pub is_builtin: bool, +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ModeId(usize); @@ -18,35 +21,3 @@ impl ModeId { self.0 } } - -#[derive(Debug, Default)] -pub struct ModeLookup { - pub(crate) modes: HashMap>, -} - -impl ModeLookup { - pub(crate) fn new() -> Self { - Self { - modes: HashMap::new(), - } - } - - pub(crate) fn lookup( - &self, - mode: ModeId, - mut matches: impl FnMut(&Pattern) -> bool, - ) -> Option<&V> { - let pattern_lookup = self.modes.get(&mode)?; - pattern_lookup.lookup(&mut matches) - } - - pub fn add_rules( - &mut self, - mode: ModeId, - rules: Vec<(Pattern, V)>, - ) { - let pattern_lookup = self.modes.entry(mode).or_insert_with(PatternLookup::new); - - pattern_lookup.add_rules(rules); - } -} diff --git a/xee-interpreter/src/pattern/pattern_core.rs b/xee-interpreter/src/pattern/pattern_core.rs index 6fd955143..91255585e 100644 --- a/xee-interpreter/src/pattern/pattern_core.rs +++ b/xee-interpreter/src/pattern/pattern_core.rs @@ -6,6 +6,8 @@ use xee_xpath_ast::pattern; use crate::function::InlineFunctionId; use crate::sequence::Item; use crate::xml; +use crate::context::TypeTableRef; +use crate::xml::TypeTable; pub(crate) enum NodeMatch { Match(Option), @@ -15,6 +17,7 @@ pub(crate) enum NodeMatch { pub(crate) trait PredicateMatcher { fn match_predicate(&mut self, inline_function_id: InlineFunctionId, item: &Item) -> bool; fn xot(&self) -> &Xot; + fn type_table(&self) -> &TypeTableRef; fn matches(&mut self, pattern: &pattern::Pattern, item: &Item) -> bool { match pattern { @@ -213,8 +216,11 @@ pub(crate) trait PredicateMatcher { } else if step.forward == pattern::ForwardAxis::Attribute { return (false, step.forward); } - if !Self::matches_node_test(&step.node_test, node, self.xot()) { - return (false, step.forward); + { + let type_table = self.type_table().borrow(); + if !Self::matches_node_test(&step.node_test, node, self.xot(), &type_table) { + return (false, step.forward); + } } // if we have a 
match, check whether the predicates apply let item = Item::Node(node); @@ -243,10 +249,17 @@ pub(crate) trait PredicateMatcher { true } - fn matches_node_test(node_test: &pattern::NodeTest, node: xot::Node, xot: &Xot) -> bool { + fn matches_node_test( + node_test: &pattern::NodeTest, + node: xot::Node, + xot: &Xot, + type_table: &TypeTable, + ) -> bool { match node_test { pattern::NodeTest::NameTest(name_test) => Self::matches_name_test(name_test, node, xot), - pattern::NodeTest::KindTest(kind_test) => Self::matches_kind_test(kind_test, node, xot), + pattern::NodeTest::KindTest(kind_test) => { + Self::matches_kind_test(kind_test, node, xot, type_table) + } } } @@ -260,7 +273,7 @@ pub(crate) trait PredicateMatcher { false } } - pattern::NameTest::Star => true, + pattern::NameTest::Star => xot.node_name_ref(node).unwrap().is_some(), pattern::NameTest::LocalName(expected_local_name) => { if let Some(name) = xot.node_name(node) { xot.local_name_str(name) == expected_local_name @@ -279,8 +292,13 @@ pub(crate) trait PredicateMatcher { } } - fn matches_kind_test(kind_test: &KindTest, node: xot::Node, xot: &Xot) -> bool { - xml::kind_test(kind_test, xot, node) + fn matches_kind_test( + kind_test: &KindTest, + node: xot::Node, + xot: &Xot, + type_table: &TypeTable, + ) -> bool { + xml::kind_test(kind_test, xot, type_table, node) } } @@ -313,6 +331,7 @@ mod tests { struct BasicPredicateMatcher<'a> { xot: &'a Xot, + type_table: TypeTableRef, predicate_matches: bool, } @@ -320,6 +339,7 @@ mod tests { fn new(xot: &'a Xot) -> Self { Self { xot, + type_table: TypeTableRef::new(), predicate_matches: false, } } @@ -327,6 +347,7 @@ mod tests { fn matching(xot: &'a Xot) -> Self { Self { xot, + type_table: TypeTableRef::new(), predicate_matches: true, } } @@ -340,6 +361,10 @@ mod tests { fn xot(&self) -> &Xot { self.xot } + + fn type_table(&self) -> &TypeTableRef { + &self.type_table + } } #[test] diff --git a/xee-interpreter/src/pattern/pattern_lookup.rs 
b/xee-interpreter/src/pattern/pattern_lookup.rs index 7e08de6b5..b2b647ed5 100644 --- a/xee-interpreter/src/pattern/pattern_lookup.rs +++ b/xee-interpreter/src/pattern/pattern_lookup.rs @@ -5,6 +5,7 @@ use crate::function; use crate::interpreter::Interpreter; use crate::pattern::pattern_core::PredicateMatcher; use crate::sequence::Item; +use crate::context::TypeTableRef; #[derive(Debug, Default)] pub struct PatternLookup { @@ -60,6 +61,10 @@ impl PredicateMatcher for Interpreter<'_> { fn xot(&self) -> &Xot { self.xot() } + + fn type_table(&self) -> &TypeTableRef { + &self.state.type_table + } } impl PatternLookup { @@ -82,4 +87,35 @@ impl PatternLookup { .find(|(pattern, _)| matches(pattern)) .map(|(_, value)| value) } + + pub(crate) fn collect_matching( + &self, + mut matches: impl FnMut(&Pattern) -> bool, + ) -> Vec + where + V: Clone, + { + self.patterns + .iter() + .filter_map(|(pattern, value)| { + if matches(pattern) { + Some(value.clone()) + } else { + None + } + }) + .collect() + } + + pub(crate) fn lookup_with_filter( + &self, + mut matches: impl FnMut(&Pattern) -> bool, + mut accept: impl FnMut(&V) -> bool, + ) -> Option<&V> { + self.patterns + .iter() + .find(|(pattern, value)| accept(value) && matches(pattern)) + .map(|(_, value)| value) + } + } diff --git a/xee-interpreter/src/sequence/matching.rs b/xee-interpreter/src/sequence/matching.rs index 955885ffe..fca892d6a 100644 --- a/xee-interpreter/src/sequence/matching.rs +++ b/xee-interpreter/src/sequence/matching.rs @@ -14,6 +14,7 @@ use crate::context; use crate::error; use crate::function; use crate::xml; +use crate::xml::TypeTable; use super::core::Sequence; use super::item::Item; @@ -26,13 +27,14 @@ impl Sequence { &self, s: &str, xot: &Xot, + type_table: &TypeTable, get_signature: &impl Fn(&function::Function) -> &'a function::Signature, ) -> error::Result { let namespaces = Namespaces::default(); let sequence_type = parse_sequence_type(s, &namespaces)?; if self .clone() - 
.sequence_type_matching(&sequence_type, xot, get_signature) + .sequence_type_matching(&sequence_type, xot, type_table, get_signature) .is_ok() { Ok(true) @@ -46,6 +48,7 @@ impl Sequence { self, sequence_type: &ast::SequenceType, xot: &Xot, + type_table: &TypeTable, get_signature: &impl Fn(&function::Function) -> &'a function::Signature, ) -> error::Result { self.sequence_type_matching_convert( @@ -53,6 +56,7 @@ impl Sequence { &|atomic, _| Ok(atomic), &|function_test, item| item.function_type_matching(function_test, &get_signature), xot, + type_table, ) } @@ -62,6 +66,7 @@ impl Sequence { sequence_type: &ast::SequenceType, context: &'a context::StaticContext, xot: &Xot, + type_table: &TypeTable, get_signature: &impl Fn(&function::Function) -> &'a function::Signature, ) -> error::Result { self.sequence_type_matching_convert( @@ -69,6 +74,7 @@ impl Sequence { &|atomic, xs| Self::cast_or_promote_atomic(atomic, xs, context), &|function_test, item| item.function_arity_matching(function_test, &get_signature), xot, + type_table, ) } @@ -99,6 +105,7 @@ impl Sequence { cast_or_promote_atomic: &impl Fn(atomic::Atomic, Xs) -> error::Result, check_function: &impl Fn(&ast::FunctionTest, &Item) -> error::Result<()>, xot: &Xot, + type_table: &TypeTable, ) -> error::Result { match t { ast::SequenceType::Empty => { @@ -113,6 +120,7 @@ impl Sequence { cast_or_promote_atomic, check_function, xot, + type_table, ), } } @@ -123,6 +131,7 @@ impl Sequence { cast_or_promote_atomic: &impl Fn(atomic::Atomic, Xs) -> error::Result, check_function: &impl Fn(&ast::FunctionTest, &Item) -> error::Result<()>, xot: &Xot, + type_table: &TypeTable, ) -> error::Result { match &occurrence_item.item_type { ast::ItemType::AtomicOrUnionType(xs) => self.atomic_occurrence_item_matching( @@ -136,6 +145,7 @@ impl Sequence { cast_or_promote_atomic, check_function, xot, + type_table, ), } } @@ -149,6 +159,7 @@ impl Sequence { cast_or_promote_atomic: &impl Fn(atomic::Atomic, Xs) -> error::Result, check_function: 
&impl Fn(&ast::FunctionTest, &Item) -> error::Result<()>, xot: &Xot, + type_table: &TypeTable, ) -> error::Result { match occurrence_item.occurrence { ast::Occurrence::One => { @@ -158,6 +169,7 @@ impl Sequence { cast_or_promote_atomic, check_function, xot, + type_table, )?; } ast::Occurrence::Option => { @@ -168,6 +180,7 @@ impl Sequence { cast_or_promote_atomic, check_function, xot, + type_table, )?; } } @@ -184,6 +197,7 @@ impl Sequence { cast_or_promote_atomic, check_function, xot, + type_table, )?; } } @@ -205,6 +219,7 @@ impl Sequence { cast_or_promote_atomic, check_function, xot, + type_table, )?; } } @@ -264,6 +279,7 @@ impl Item { cast_or_promote_atomic: &impl Fn(atomic::Atomic, Xs) -> error::Result, check_function: &impl Fn(&ast::FunctionTest, &Item) -> error::Result<()>, xot: &Xot, + type_table: &TypeTable, ) -> error::Result<()> { match item_type { ast::ItemType::Item => {} @@ -271,7 +287,7 @@ impl Item { unreachable!() } ast::ItemType::KindTest(kind_test) => { - self.kind_test_matching(kind_test, xot)?; + self.kind_test_matching(kind_test, xot, type_table)?; } ast::ItemType::FunctionTest(function_test) => { check_function(function_test, &self)?; @@ -294,6 +310,7 @@ impl Item { cast_or_promote_atomic, check_function, xot, + type_table, )?; } } @@ -312,6 +329,7 @@ impl Item { cast_or_promote_atomic, check_function, xot, + type_table, )?; } } @@ -320,10 +338,15 @@ impl Item { Ok(()) } - fn kind_test_matching(&self, kind_test: &ast::KindTest, xot: &Xot) -> error::Result<()> { + fn kind_test_matching( + &self, + kind_test: &ast::KindTest, + xot: &Xot, + type_table: &TypeTable, + ) -> error::Result<()> { match self { Item::Node(node) => { - if xml::kind_test(kind_test, xot, *node) { + if xml::kind_test(kind_test, xot, type_table, *node) { Ok(()) } else { Err(error::Error::XPTY0004) @@ -451,19 +474,21 @@ mod tests { let wrong_amount_sequence: Sequence = vec![ibig!(1), ibig!(2)].into(); let wrong_type_sequence: Sequence = vec![false].into(); let xot = 
Xot::new(); + let type_table = TypeTable::new(); let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(&right_result.unwrap(), &right_sequence); let wrong_amount_result = - wrong_amount_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + wrong_amount_sequence.sequence_type_matching(&sequence_type, &xot, &type_table, &|_| unreachable!()); assert_eq!(wrong_amount_result, Err(error::Error::XPTY0004)); let wrong_type_result = - wrong_type_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + wrong_type_sequence.sequence_type_matching(&sequence_type, &xot, &type_table, &|_| unreachable!()); assert_eq!(wrong_type_result, Err(error::Error::XPTY0004)); } @@ -476,18 +501,20 @@ mod tests { let wrong_amount_sequence = Sequence::from(vec![Item::from(1i64), Item::from(1i64)]); let wrong_type_sequence = Sequence::from(vec![Item::from(atomic::Atomic::from(false))]); let xot = Xot::new(); + let type_table = TypeTable::new(); let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_result, Ok(right_sequence)); let wrong_amount_result = - wrong_amount_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + wrong_amount_sequence.sequence_type_matching(&sequence_type, &xot, &type_table, &|_| unreachable!()); assert_eq!(wrong_amount_result, Err(error::Error::XPTY0004)); let wrong_type_result = - wrong_type_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + wrong_type_sequence.sequence_type_matching(&sequence_type, &xot, &type_table, &|_| unreachable!()); assert_eq!(wrong_type_result, Err(error::Error::XPTY0004)); } @@ -501,19 +528,26 @@ mod tests { Sequence::from(vec![Item::from(ibig!(1)), Item::from(ibig!(2))]); let right_type_sequence2 = Sequence::from(vec![Item::from(atomic::Atomic::from(false))]); let xot = 
Xot::new(); + let type_table = TypeTable::new(); let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_result, Ok(right_sequence)); - let wrong_amount_result = - wrong_amount_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + let wrong_amount_result = wrong_amount_sequence.sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| unreachable!(), + ); assert_eq!(wrong_amount_result, Err(error::Error::XPTY0004)); let right_type_result2 = right_type_sequence2.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_type_result2, Ok(right_type_sequence2)); @@ -532,20 +566,27 @@ mod tests { Item::from(atomic::Atomic::from(2i64)), ]); let right_type_sequence2 = Sequence::from(vec![Item::from(node)]); + let type_table = TypeTable::new(); let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_result, Ok(right_sequence)); - let wrong_amount_result = - wrong_amount_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + let wrong_amount_result = wrong_amount_sequence.sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| unreachable!(), + ); assert_eq!(wrong_amount_result, Err(error::Error::XPTY0004)); let right_type_result2 = right_type_sequence2.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_type_result2, Ok(right_type_sequence2)); @@ -561,19 +602,26 @@ mod tests { Sequence::from(vec![Item::from(ibig!(1)), Item::from(ibig!(2))]); let right_empty_sequence = Sequence::default(); let xot = Xot::new(); + let type_table = TypeTable::new(); let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_result, Ok(right_sequence)); 
- let wrong_amount_result = - wrong_amount_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + let wrong_amount_result = wrong_amount_sequence.sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| unreachable!(), + ); assert_eq!(wrong_amount_result, Err(error::Error::XPTY0004)); let right_empty_result = right_empty_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_empty_result, Ok(right_empty_sequence)); @@ -588,10 +636,12 @@ mod tests { let right_multi_sequence = Sequence::from(vec![Item::from(ibig!(1)), Item::from(ibig!(2))]); let right_empty_sequence = Sequence::default(); let xot = Xot::new(); + let type_table = TypeTable::new(); let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_result, Ok(right_sequence)); @@ -599,6 +649,7 @@ mod tests { let right_multi_result = right_multi_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_multi_result, Ok(right_multi_sequence)); @@ -606,6 +657,7 @@ mod tests { let right_empty_result = right_empty_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_empty_result, Ok(right_empty_sequence)); @@ -626,6 +678,7 @@ mod tests { .attributes(a) .get_node(xot.name("attr").unwrap()) .unwrap(); + let type_table = TypeTable::new(); let right_sequence = Sequence::from(vec![ Item::from(doc), @@ -640,12 +693,17 @@ mod tests { let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_result, Ok(right_sequence)); - let wrong_result = - wrong_sequence.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + let wrong_result = wrong_sequence.sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| 
unreachable!(), + ); assert_eq!(wrong_result, Err(error::Error::XPTY0004)); } @@ -664,6 +722,7 @@ mod tests { .attributes(a) .get_node(xot.name("attr").unwrap()) .unwrap(); + let type_table = TypeTable::new(); let right_sequence = Sequence::from(vec![Item::from(doc), Item::from(a), Item::from(b)]); @@ -673,15 +732,24 @@ mod tests { let right_result = right_sequence.clone().sequence_type_matching( &sequence_type, &xot, + &type_table, &|_| unreachable!(), ); assert_eq!(right_result, Ok(right_sequence)); - let wrong_result = - wrong_sequence_text.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + let wrong_result = wrong_sequence_text.sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| unreachable!(), + ); assert_eq!(wrong_result, Err(error::Error::XPTY0004)); - let wrong_result = - wrong_sequence_attr.sequence_type_matching(&sequence_type, &xot, &|_| unreachable!()); + let wrong_result = wrong_sequence_attr.sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| unreachable!(), + ); assert_eq!(wrong_result, Err(error::Error::XPTY0004)); } @@ -695,10 +763,12 @@ mod tests { let static_context = context::StaticContext::default(); let xot = Xot::new(); + let type_table = TypeTable::new(); let right_result = right_sequence.sequence_type_matching_function_conversion( &sequence_type, &static_context, &xot, + &type_table, &|_| unreachable!(), ); // atomization has changed the result sequence @@ -725,11 +795,13 @@ mod tests { let right_sequence = Sequence::from(vec![Item::from(a), Item::from(b)]); let static_context = context::StaticContext::default(); + let type_table = TypeTable::new(); let right_result = right_sequence.sequence_type_matching_function_conversion( &sequence_type, &static_context, &xot, + &type_table, &|_| unreachable!(), ); // atomization has changed the result sequence @@ -758,11 +830,14 @@ mod tests { ); let xot = Xot::new(); + let type_table = TypeTable::new(); - let right_result = - right_sequence - 
.clone() - .sequence_type_matching(&sequence_type, &xot, &|_| &signature); + let right_result = right_sequence.clone().sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| &signature, + ); assert_eq!(&right_result.unwrap(), &right_sequence); } @@ -783,11 +858,14 @@ mod tests { ); let xot = Xot::new(); + let type_table = TypeTable::new(); - let right_result = - right_sequence - .clone() - .sequence_type_matching(&sequence_type, &xot, &|_| &signature); + let right_result = right_sequence.clone().sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| &signature, + ); assert_eq!(&right_result.unwrap(), &right_sequence); } @@ -808,11 +886,14 @@ mod tests { ); let xot = Xot::new(); + let type_table = TypeTable::new(); - let right_result = - right_sequence - .clone() - .sequence_type_matching(&sequence_type, &xot, &|_| &signature); + let right_result = right_sequence.clone().sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| &signature, + ); assert_eq!(&right_result.unwrap(), &right_sequence); } @@ -834,11 +915,14 @@ mod tests { ); let xot = Xot::new(); + let type_table = TypeTable::new(); - let wrong_result = - wrong_sequence - .clone() - .sequence_type_matching(&sequence_type, &xot, &|_| &signature); + let wrong_result = wrong_sequence.clone().sequence_type_matching( + &sequence_type, + &xot, + &type_table, + &|_| &signature, + ); assert_eq!(wrong_result, Err(error::Error::XPTY0004)); } } diff --git a/xee-interpreter/src/sequence/opc.rs b/xee-interpreter/src/sequence/opc.rs index a2c2c1cc8..f8fdca74e 100644 --- a/xee-interpreter/src/sequence/opc.rs +++ b/xee-interpreter/src/sequence/opc.rs @@ -6,6 +6,7 @@ use xee_xpath_ast::ast; use xot::xmlname::{NameStrInfo, OwnedName}; use xot::Xot; +use crate::xml::TypeTable; use crate::{atomic, context, error, function::Map}; #[derive(Debug, Clone, PartialEq, Eq)] @@ -35,6 +36,7 @@ pub(crate) struct OptionParameterConverter<'a> { map: &'a Map, static_context: &'a 
context::StaticContext, xot: &'a Xot, + type_table: &'a TypeTable, } impl<'a> OptionParameterConverter<'a> { @@ -42,11 +44,13 @@ impl<'a> OptionParameterConverter<'a> { map: &'a Map, static_context: &'a context::StaticContext, xot: &'a Xot, + type_table: &'a TypeTable, ) -> Self { Self { map, static_context, xot, + type_table, } } @@ -61,6 +65,7 @@ impl<'a> OptionParameterConverter<'a> { atomic_type, self.static_context, self.xot, + self.type_table, )?; let value = if let Some(value) = value { value.option()? @@ -102,6 +107,7 @@ impl<'a> OptionParameterConverter<'a> { atomic_type, self.static_context, self.xot, + self.type_table, )?; let values = if let Some(value) = value { value diff --git a/xee-interpreter/src/sequence/serialization.rs b/xee-interpreter/src/sequence/serialization.rs index 3c3157be8..9e7280c86 100644 --- a/xee-interpreter/src/sequence/serialization.rs +++ b/xee-interpreter/src/sequence/serialization.rs @@ -7,6 +7,7 @@ use xee_schema_type::Xs; use crate::{ atomic, context, error, function::{self, Map}, + xml::TypeTable, }; use super::{ @@ -72,8 +73,9 @@ impl SerializationParameters { map: Map, static_context: &context::StaticContext, xot: &Xot, + type_table: &TypeTable, ) -> error::Result<Self> { - let c = OptionParameterConverter::new(&map, static_context, xot); + let c = OptionParameterConverter::new(&map, static_context, xot, type_table); let allow_duplicate_names = c.option_with_default("allow-duplicate-names", Xs::Boolean, false)?; @@ -426,6 +428,13 @@ mod tests { use super::*; + fn params_for(map: Map) -> SerializationParameters { + let static_context = context::StaticContext::default(); + let xot = Xot::new(); + let type_table = TypeTable::new(); + SerializationParameters::from_map(map, &static_context, &xot, &type_table).unwrap() + } + #[test] fn test_allow_duplicate_names_true() { let map = Map::new(vec![( @@ -433,9 +442,7 @@ mod tests { sequence::Sequence::from(vec![atomic::Atomic::Boolean(true)]), )]) .unwrap(); - let static_context = 
context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert!(params.allow_duplicate_names); } @@ -446,9 +453,7 @@ mod tests { sequence::Sequence::from(vec![atomic::Atomic::Boolean(false)]), )]) .unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert!(!params.allow_duplicate_names); } @@ -459,18 +464,14 @@ mod tests { sequence::Sequence::default(), )]) .unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert!(!params.allow_duplicate_names); } #[test] fn test_allow_duplicate_names_missing() { let map = Map::new(vec![]).unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert!(!params.allow_duplicate_names); } @@ -486,9 +487,7 @@ mod tests { ]), )]) .unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert_eq!(params.cdata_section_elements.len(), 2); assert_eq!(params.cdata_section_elements[0], html); assert_eq!(params.cdata_section_elements[1], script); @@ -502,9 +501,7 @@ mod tests { sequence::Sequence::from(vec![json]), )]) .unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert_eq!( params.json_node_output_method, 
QNameOrString::String("json".to_string()) @@ -520,9 +517,7 @@ mod tests { sequence::Sequence::from(vec![json]), )]) .unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert_eq!( params.json_node_output_method, QNameOrString::QName(owned_name) @@ -536,9 +531,7 @@ mod tests { sequence::Sequence::default(), )]) .unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert_eq!( params.json_node_output_method, QNameOrString::String("xml".to_string()) @@ -548,9 +541,7 @@ mod tests { #[test] fn test_qname_or_string_default_missing() { let map = Map::new(vec![]).unwrap(); - let static_context = context::StaticContext::default(); - let xot = Xot::new(); - let params = SerializationParameters::from_map(map, &static_context, &xot).unwrap(); + let params = params_for(map); assert_eq!( params.json_node_output_method, QNameOrString::String("xml".to_string()) diff --git a/xee-interpreter/src/xml/kind_test.rs b/xee-interpreter/src/xml/kind_test.rs index 78061ac7b..fcd15d4fe 100644 --- a/xee-interpreter/src/xml/kind_test.rs +++ b/xee-interpreter/src/xml/kind_test.rs @@ -3,14 +3,21 @@ use xot::Xot; use xee_xpath_ast::ast; -pub(crate) fn kind_test(kind_test: &ast::KindTest, xot: &Xot, node: xot::Node) -> bool { +use super::TypeTable; + +pub(crate) fn kind_test( + kind_test: &ast::KindTest, + xot: &Xot, + type_table: &TypeTable, + node: xot::Node, +) -> bool { match kind_test { - ast::KindTest::Document(dt) => document_test(dt.as_ref(), xot, node), - ast::KindTest::Element(et) => element_test(et.as_ref(), xot, node), + ast::KindTest::Document(dt) => document_test(dt.as_ref(), xot, type_table, node), + ast::KindTest::Element(et) => element_test(et.as_ref(), xot, 
type_table, node), ast::KindTest::SchemaElement(_set) => { todo!() } - ast::KindTest::Attribute(at) => attribute_test(at.as_ref(), xot, node), + ast::KindTest::Attribute(at) => attribute_test(at.as_ref(), xot, type_table, node), ast::KindTest::SchemaAttribute(_sat) => { todo!() } @@ -35,15 +42,36 @@ pub(crate) fn kind_test(kind_test: &ast::KindTest, xot: &Xot, node: xot::Node) - } } -fn element_test(test: Option<&ast::ElementOrAttributeTest>, xot: &Xot, node: xot::Node) -> bool { - element_or_attribute_test(test, xot, node, |node, xot| xot.is_element(node)) +fn element_test( + test: Option<&ast::ElementOrAttributeTest>, + xot: &Xot, + type_table: &TypeTable, + node: xot::Node, +) -> bool { + element_or_attribute_test(test, xot, type_table, node, |node, xot| xot.is_element(node)) } -fn attribute_test(test: Option<&ast::ElementOrAttributeTest>, xot: &Xot, node: xot::Node) -> bool { - element_or_attribute_test(test, xot, node, |node, xot| xot.is_attribute_node(node)) +fn attribute_test( + test: Option<&ast::ElementOrAttributeTest>, + xot: &Xot, + type_table: &TypeTable, + node: xot::Node, +) -> bool { + element_or_attribute_test( + test, + xot, + type_table, + node, + |node, xot| xot.is_attribute_node(node), + ) } -fn document_test(test: Option<&ast::DocumentTest>, xot: &Xot, node: xot::Node) -> bool { +fn document_test( + test: Option<&ast::DocumentTest>, + xot: &Xot, + type_table: &TypeTable, + node: xot::Node, +) -> bool { if !xot.is_document(node) { return false; } @@ -55,7 +83,9 @@ fn document_test(test: Option<&ast::DocumentTest>, xot: &Xot, node: xot::Node) - let document_element_node = xot.document_element(node).unwrap(); match document_test { - ast::DocumentTest::Element(et) => element_test(et.as_ref(), xot, document_element_node), + ast::DocumentTest::Element(et) => { + element_test(et.as_ref(), xot, type_table, document_element_node) + } ast::DocumentTest::SchemaElement(_set) => { todo!() } @@ -68,6 +98,7 @@ fn document_test(test: 
Option<&ast::DocumentTest>, xot: &Xot, node: xot::Node) - fn element_or_attribute_test( test: Option<&ast::ElementOrAttributeTest>, xot: &Xot, + type_table: &TypeTable, node: xot::Node, node_type_match: impl Fn(xot::Node, &Xot) -> bool, ) -> bool { @@ -95,7 +126,7 @@ fn element_or_attribute_test( } // the type also has to match if let Some(type_name) = &test.type_name { - type_annotation(xot, node).derives_from(type_name.name) + type_annotation(xot, type_table, node).derives_from(type_name.name) // ignoring can_be_nilled for now } else { true @@ -106,9 +137,18 @@ fn element_or_attribute_test( } } -fn type_annotation(_xot: &Xot, _node: xot::Node) -> Xs { +fn type_annotation(xot: &Xot, type_table: &TypeTable, node: xot::Node) -> Xs { // for now we don't know any types of nodes yet - Xs::UntypedAtomic + if let Some(xs) = type_table.get(node) { + return xs; + } + if xot.is_element(node) { + Xs::Untyped + } else if xot.is_attribute_node(node) { + Xs::UntypedAtomic + } else { + Xs::UntypedAtomic + } } #[cfg(test)] @@ -123,11 +163,12 @@ mod tests { let doc = xot.parse(r#""#).unwrap(); let doc_el = xot.document_element(doc).unwrap(); let a = xot.first_child(doc_el).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("node()").unwrap(); - assert!(kind_test(&kt, &xot, doc)); - assert!(kind_test(&kt, &xot, doc_el)); - assert!(kind_test(&kt, &xot, a)); + assert!(kind_test(&kt, &xot, &type_table, doc)); + assert!(kind_test(&kt, &xot, &type_table, doc_el)); + assert!(kind_test(&kt, &xot, &type_table, a)); } #[test] @@ -137,12 +178,13 @@ mod tests { let doc_el = xot.document_element(doc).unwrap(); let a = xot.first_child(doc_el).unwrap(); let a_text = xot.first_child(a).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("text()").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); - assert!(!kind_test(&kt, &xot, a)); - assert!(kind_test(&kt, &xot, a_text)); + assert!(!kind_test(&kt, &xot, &type_table, 
doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(!kind_test(&kt, &xot, &type_table, a)); + assert!(kind_test(&kt, &xot, &type_table, a_text)); } #[test] @@ -151,11 +193,12 @@ mod tests { let doc = xot.parse(r#""#).unwrap(); let doc_el = xot.document_element(doc).unwrap(); let comment = xot.first_child(doc_el).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("comment()").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); - assert!(kind_test(&kt, &xot, comment)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(kind_test(&kt, &xot, &type_table, comment)); } #[test] @@ -163,9 +206,10 @@ mod tests { let mut xot = Xot::new(); let doc = xot.parse(r#""#).unwrap(); let doc_el = xot.document_element(doc).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("document-node()").unwrap(); - assert!(kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); + assert!(kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); } #[test] @@ -175,12 +219,13 @@ mod tests { let doc_el = xot.document_element(doc).unwrap(); let a = xot.first_child(doc_el).unwrap(); let text = xot.first_child(a).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("element()").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(kind_test(&kt, &xot, doc_el)); - assert!(kind_test(&kt, &xot, a)); - assert!(!kind_test(&kt, &xot, text)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(kind_test(&kt, &xot, &type_table, doc_el)); + assert!(kind_test(&kt, &xot, &type_table, a)); + assert!(!kind_test(&kt, &xot, &type_table, text)); } #[test] @@ -190,12 +235,13 @@ mod tests { let doc_el = xot.document_element(doc).unwrap(); let a = xot.first_child(doc_el).unwrap(); let text = xot.first_child(a).unwrap(); + let type_table = TypeTable::new(); let kt = 
parse_kind_test("element(*)").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(kind_test(&kt, &xot, doc_el)); - assert!(kind_test(&kt, &xot, a)); - assert!(!kind_test(&kt, &xot, text)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(kind_test(&kt, &xot, &type_table, doc_el)); + assert!(kind_test(&kt, &xot, &type_table, a)); + assert!(!kind_test(&kt, &xot, &type_table, text)); } #[test] @@ -205,12 +251,13 @@ mod tests { let doc_el = xot.document_element(doc).unwrap(); let a = xot.first_child(doc_el).unwrap(); let text = xot.first_child(a).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("element(a)").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); - assert!(kind_test(&kt, &xot, a)); - assert!(!kind_test(&kt, &xot, text)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(kind_test(&kt, &xot, &type_table, a)); + assert!(!kind_test(&kt, &xot, &type_table, text)); } #[test] @@ -220,16 +267,20 @@ mod tests { let doc_el = xot.document_element(doc).unwrap(); let a = xot.first_child(doc_el).unwrap(); let text = xot.first_child(a).unwrap(); + let type_table = TypeTable::new(); + + let kt = parse_kind_test("element(a, xs:untyped)").unwrap(); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(kind_test(&kt, &xot, &type_table, a)); + assert!(!kind_test(&kt, &xot, &type_table, text)); let kt = parse_kind_test("element(a, xs:untypedAtomic)").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); - assert!(kind_test(&kt, &xot, a)); - assert!(!kind_test(&kt, &xot, text)); + assert!(!kind_test(&kt, &xot, &type_table, a)); // but we're not an xs:string let kt = parse_kind_test("element(a, xs:string)").unwrap(); - assert!(!kind_test(&kt, &xot, a)); + assert!(!kind_test(&kt, &xot, &type_table, a)); } #[test] @@ -245,14 +296,15 @@ 
mod tests { let text = xot.first_child(a).unwrap(); let alpha = xot.attributes(a).get_node(alpha).unwrap(); let beta = xot.attributes(a).get_node(beta).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("attribute()").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); - assert!(!kind_test(&kt, &xot, a)); - assert!(kind_test(&kt, &xot, alpha)); - assert!(kind_test(&kt, &xot, beta)); - assert!(!kind_test(&kt, &xot, text)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(!kind_test(&kt, &xot, &type_table, a)); + assert!(kind_test(&kt, &xot, &type_table, alpha)); + assert!(kind_test(&kt, &xot, &type_table, beta)); + assert!(!kind_test(&kt, &xot, &type_table, text)); } #[test] @@ -268,14 +320,15 @@ mod tests { let text = xot.first_child(a).unwrap(); let alpha = xot.attributes(a).get_node(alpha).unwrap(); let beta = xot.attributes(a).get_node(beta).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("attribute(alpha)").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); - assert!(!kind_test(&kt, &xot, a)); - assert!(kind_test(&kt, &xot, alpha)); - assert!(!kind_test(&kt, &xot, beta)); - assert!(!kind_test(&kt, &xot, text)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(!kind_test(&kt, &xot, &type_table, a)); + assert!(kind_test(&kt, &xot, &type_table, alpha)); + assert!(!kind_test(&kt, &xot, &type_table, beta)); + assert!(!kind_test(&kt, &xot, &type_table, text)); } #[test] @@ -291,17 +344,18 @@ mod tests { let text = xot.first_child(a).unwrap(); let alpha = xot.attributes(a).get_node(alpha).unwrap(); let beta = xot.attributes(a).get_node(beta).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("attribute(alpha, xs:untypedAtomic)").unwrap(); - assert!(!kind_test(&kt, &xot, doc)); - 
assert!(!kind_test(&kt, &xot, doc_el)); - assert!(!kind_test(&kt, &xot, a)); - assert!(kind_test(&kt, &xot, alpha)); - assert!(!kind_test(&kt, &xot, beta)); - assert!(!kind_test(&kt, &xot, text)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(!kind_test(&kt, &xot, &type_table, a)); + assert!(kind_test(&kt, &xot, &type_table, alpha)); + assert!(!kind_test(&kt, &xot, &type_table, beta)); + assert!(!kind_test(&kt, &xot, &type_table, text)); let kt = parse_kind_test("attribute(alpha, xs:string)").unwrap(); - assert!(!kind_test(&kt, &xot, alpha)); + assert!(!kind_test(&kt, &xot, &type_table, alpha)); } #[test] @@ -311,17 +365,18 @@ mod tests { let doc_el = xot.document_element(doc).unwrap(); let a = xot.first_child(doc_el).unwrap(); let text = xot.first_child(a).unwrap(); + let type_table = TypeTable::new(); let kt = parse_kind_test("document-node(element(root))").unwrap(); - assert!(kind_test(&kt, &xot, doc)); - assert!(!kind_test(&kt, &xot, doc_el)); - assert!(!kind_test(&kt, &xot, a)); - assert!(!kind_test(&kt, &xot, text)); + assert!(kind_test(&kt, &xot, &type_table, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc_el)); + assert!(!kind_test(&kt, &xot, &type_table, a)); + assert!(!kind_test(&kt, &xot, &type_table, text)); let kt = parse_kind_test("document-node(element(a))").unwrap(); // the document doesn't match as its root node isn't 'a' - assert!(!kind_test(&kt, &xot, doc)); + assert!(!kind_test(&kt, &xot, &type_table, doc)); // the 'a' node doesn't match either as it's not a document node - assert!(!kind_test(&kt, &xot, a)); + assert!(!kind_test(&kt, &xot, &type_table, a)); } } diff --git a/xee-interpreter/src/xml/mod.rs b/xee-interpreter/src/xml/mod.rs index f68b51624..bb69b5987 100644 --- a/xee-interpreter/src/xml/mod.rs +++ b/xee-interpreter/src/xml/mod.rs @@ -4,6 +4,7 @@ mod document; mod document_order; mod kind_test; mod step; +mod type_table; pub(crate) use 
base::BaseUriResolver; pub use document::{Document, DocumentHandle, Documents, DocumentsError}; @@ -11,3 +12,4 @@ pub(crate) use document_order::DocumentOrderAccess; pub(crate) use kind_test::kind_test; pub(crate) use step::resolve_step; pub use step::Step; +pub use type_table::TypeTable; diff --git a/xee-interpreter/src/xml/step.rs b/xee-interpreter/src/xml/step.rs index 6a76c9b1e..d777fa691 100644 --- a/xee-interpreter/src/xml/step.rs +++ b/xee-interpreter/src/xml/step.rs @@ -5,6 +5,7 @@ use xee_xpath_ast::ast; use crate::sequence; use super::kind_test::kind_test; +use super::TypeTable; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Step { @@ -12,10 +13,15 @@ pub struct Step { pub node_test: ast::NodeTest, } -pub(crate) fn resolve_step(step: &Step, node: xot::Node, xot: &Xot) -> sequence::Sequence { +pub(crate) fn resolve_step( + step: &Step, + node: xot::Node, + xot: &Xot, + type_table: &TypeTable, +) -> sequence::Sequence { let mut new_items = Vec::new(); for axis_node in node_take_axis(&step.axis, xot, node) { - if node_test(&step.node_test, &step.axis, xot, axis_node) { + if node_test(&step.node_test, &step.axis, xot, type_table, axis_node) { new_items.push(sequence::Item::Node(axis_node)); } } @@ -49,9 +55,15 @@ fn node_take_axis<'a>( xot.axis(axis, node) } -fn node_test(node_test: &ast::NodeTest, axis: &ast::Axis, xot: &Xot, node: xot::Node) -> bool { +fn node_test( + node_test: &ast::NodeTest, + axis: &ast::Axis, + xot: &Xot, + type_table: &TypeTable, + node: xot::Node, +) -> bool { match node_test { - ast::NodeTest::KindTest(kt) => kind_test(kt, xot, node), + ast::NodeTest::KindTest(kt) => kind_test(kt, xot, type_table, node), ast::NodeTest::NameTest(name_test) => { if xot.value_type(node) != principal_node_kind(axis) { return false; @@ -136,12 +148,13 @@ mod tests { let doc_el = xot.document_element(doc)?; let a = xot.first_child(doc_el).unwrap(); let b = xot.next_sibling(a).unwrap(); + let type_table = TypeTable::new(); let step = Step { axis: 
ast::Axis::Child, node_test: ast::NodeTest::NameTest(ast::NameTest::Star), }; - let value = resolve_step(&step, doc_el, &xot); + let value = resolve_step(&step, doc_el, &xot, &type_table); assert_eq!(value, xot_nodes_to_value(&[a, b])); Ok(()) } @@ -152,6 +165,7 @@ mod tests { let doc = xot.parse(r#""#).unwrap(); let doc_el = xot.document_element(doc)?; let a = xot.first_child(doc_el).unwrap(); + let type_table = TypeTable::new(); let step = Step { axis: ast::Axis::Child, @@ -159,7 +173,7 @@ mod tests { ast::Name::name("a").with_empty_span(), )), }; - let value = resolve_step(&step, doc_el, &xot); + let value = resolve_step(&step, doc_el, &xot, &type_table); assert_eq!(value, xot_nodes_to_value(&[a])); Ok(()) } diff --git a/xee-interpreter/src/xml/type_table.rs b/xee-interpreter/src/xml/type_table.rs new file mode 100644 index 000000000..029139250 --- /dev/null +++ b/xee-interpreter/src/xml/type_table.rs @@ -0,0 +1,35 @@ +use ahash::HashMap; +use xee_schema_type::Xs; + +#[derive(Debug, Clone, Default)] +pub struct TypeTable { + types: HashMap<xot::Node, Xs>, +} + +impl TypeTable { + pub fn new() -> Self { + Self::default() + } + + pub fn clear(&mut self) { + self.types.clear(); + } + + pub fn get(&self, node: xot::Node) -> Option<Xs> { + self.types.get(&node).copied() + } + + pub fn set(&mut self, node: xot::Node, xs: Xs) { + self.types.insert(node, xs); + } + + pub fn copy_type(&mut self, from: xot::Node, to: xot::Node) { + if let Some(xs) = self.get(from) { + self.set(to, xs); + } + } + + pub fn remove(&mut self, node: xot::Node) { + self.types.remove(&node); + } +} diff --git a/xee-ir/src/builder.rs b/xee-ir/src/builder.rs index f93e04956..6c25e9917 100644 --- a/xee-ir/src/builder.rs +++ b/xee-ir/src/builder.rs @@ -3,7 +3,7 @@ use xee_xpath_ast::ast; use xee_interpreter::interpreter::instruction::{ encode_instruction, instruction_size, Instruction, }; -use xee_interpreter::{context, function, interpreter, sequence, span, xml}; +use xee_interpreter::{context, declaration, function, 
interpreter, sequence, span, xml}; use crate::ir; @@ -198,4 +198,21 @@ impl<'a> FunctionBuilder<'a> { ) -> function::InlineFunctionId { self.program.add_function(function) } + + pub(crate) fn set_function( + &mut self, + function_id: function::InlineFunctionId, + function: function::InlineFunction, + ) { + self.program.set_function(function_id, function); + } + + pub(crate) fn add_try_catch(&mut self, entry: declaration::TryCatch) -> u16 { + let id = self.program.declarations.try_catches.len(); + if id > u16::MAX as usize { + panic!("too many try/catch entries"); + } + self.program.declarations.try_catches.push(entry); + id as u16 + } } diff --git a/xee-ir/src/compile.rs b/xee-ir/src/compile.rs index 4a0c5df28..dab3cb865 100644 --- a/xee-ir/src/compile.rs +++ b/xee-ir/src/compile.rs @@ -1,5 +1,10 @@ -use ahash::HashMapExt; -use xee_interpreter::{context::StaticContext, error::SpannedResult, interpreter::Program}; +use ahash::{HashMap, HashMapExt}; +use xee_interpreter::{ + context::StaticContext, + error::SpannedResult, + function, + interpreter::Program, +}; use crate::{ declaration_compiler::{DeclarationCompiler, ModeIds}, @@ -11,7 +16,16 @@ pub fn compile_xpath(expr: ir::ExprS, static_context: StaticContext) -> SpannedR let mut scopes = Scopes::new(); let builder = FunctionBuilder::new(&mut program); let empty_mode_ids = ModeIds::new(); - let mut compiler = FunctionCompiler::new(builder, &mut scopes, &empty_mode_ids); + let empty_user_functions: Vec = Vec::new(); + let empty_named_templates: HashMap = + HashMap::new(); + let mut compiler = FunctionCompiler::new( + builder, + &mut scopes, + &empty_mode_ids, + &empty_user_functions, + &empty_named_templates, + ); compiler.compile_expr(&expr)?; Ok(program) } diff --git a/xee-ir/src/declaration_compiler.rs b/xee-ir/src/declaration_compiler.rs index ca0147c33..3843261d5 100644 --- a/xee-ir/src/declaration_compiler.rs +++ b/xee-ir/src/declaration_compiler.rs @@ -1,30 +1,40 @@ use ahash::{HashMap, HashMapExt, HashSet, 
HashSetExt}; use rust_decimal::Decimal; -use xee_interpreter::pattern::ModeId; +use xee_interpreter::declaration::{ModeConfig, OnNoMatch}; +use xee_interpreter::pattern::{ModeId, RuleEntry}; use xee_xpath_ast::Pattern; use crate::function_compiler::Scopes; use crate::{ir, FunctionBuilder, FunctionCompiler}; -use xee_interpreter::{error, function, interpreter}; +use xee_interpreter::{ + declaration::{GlobalParam, TemplateParam}, + error, function, interpreter, +}; use xee_xpath_ast::pattern::transform_pattern; #[derive(Debug, Clone)] pub(crate) struct RuleBuilder { priority: Decimal, declaration_order: i64, + import_level: u32, + is_builtin: bool, pattern: Pattern, function_id: function::InlineFunctionId, } impl RuleBuilder { - fn rule( - self, - ) -> ( - Pattern, - function::InlineFunctionId, - ) { - (self.pattern, self.function_id) + fn rule(self) -> (Pattern, RuleEntry) { + ( + self.pattern, + RuleEntry { + function_id: self.function_id, + priority: self.priority, + import_level: self.import_level, + declaration_order: self.declaration_order, + is_builtin: self.is_builtin, + }, + ) } } @@ -36,6 +46,8 @@ pub struct DeclarationCompiler<'a> { rule_declaration_order: i64, rule_builders: HashMap>, mode_ids: ModeIds, + user_function_ids: Vec, + named_template_ids: HashMap, } impl<'a> DeclarationCompiler<'a> { @@ -46,12 +58,20 @@ impl<'a> DeclarationCompiler<'a> { rule_declaration_order: 0, rule_builders: HashMap::new(), mode_ids: HashMap::new(), + user_function_ids: Vec::new(), + named_template_ids: HashMap::new(), } } fn function_compiler(&mut self) -> FunctionCompiler<'_> { let function_builder = FunctionBuilder::new(self.program); - FunctionCompiler::new(function_builder, &mut self.scopes, &self.mode_ids) + FunctionCompiler::new( + function_builder, + &mut self.scopes, + &self.mode_ids, + &self.user_function_ids, + &self.named_template_ids, + ) } pub fn compile_declarations( @@ -61,10 +81,17 @@ impl<'a> DeclarationCompiler<'a> { // first keep track of what modes 
exist, to create a ModeId for them. We do // this early so any mode reference within apply-templates will resolve. self.compile_modes(declarations); + self.compile_mode_configs(declarations); + + self.prepare_user_function_ids(declarations)?; + self.prepare_named_template_ids(declarations)?; + self.compile_user_functions(declarations)?; + self.compile_named_templates(declarations)?; for rule in &declarations.rules { self.compile_rule(rule)?; } + self.compile_global_params(declarations)?; // now add compiled rules from builder to the program self.add_rules(); let mut function_compiler = self.function_compiler(); @@ -91,18 +118,195 @@ impl<'a> DeclarationCompiler<'a> { self.mode_ids.insert(apply_templates_mode_value, mode_id); } } + for (mode_name, _) in &declarations.modes { + let apply_templates_mode_value = match mode_name { + Some(name) => ir::ApplyTemplatesModeValue::Named(name.clone()), + None => ir::ApplyTemplatesModeValue::Unnamed, + }; + if self.mode_ids.contains_key(&apply_templates_mode_value) { + continue; + } + let mode_id = ModeId::new(self.mode_ids.len()); + self.mode_ids.insert(apply_templates_mode_value, mode_id); + } + } + + fn compile_mode_configs(&mut self, declarations: &ir::Declarations) { + self.program.declarations.mode_configs.clear(); + for (mode_name, mode) in &declarations.modes { + let apply_templates_mode_value = match mode_name { + Some(name) => ir::ApplyTemplatesModeValue::Named(name.clone()), + None => ir::ApplyTemplatesModeValue::Unnamed, + }; + if let Some(mode_id) = self.mode_ids.get(&apply_templates_mode_value).cloned() { + let on_no_match = mode.on_no_match.as_ref().map(|m| match m { + ir::OnNoMatch::DeepCopy => OnNoMatch::DeepCopy, + ir::OnNoMatch::ShallowCopy => OnNoMatch::ShallowCopy, + ir::OnNoMatch::DeepSkip => OnNoMatch::DeepSkip, + ir::OnNoMatch::ShallowSkip => OnNoMatch::ShallowSkip, + ir::OnNoMatch::TextOnlyCopy => OnNoMatch::TextOnlyCopy, + ir::OnNoMatch::Fail => OnNoMatch::Fail, + }); + self.program + .declarations + 
.mode_configs + .insert(mode_id, ModeConfig { on_no_match }); + } + } + } + + fn prepare_user_function_ids( + &mut self, + declarations: &ir::Declarations, + ) -> error::SpannedResult<()> { + self.user_function_ids.clear(); + if declarations.functions.is_empty() { + self.program.declarations.user_functions.clear(); + return Ok(()); + } + let start_index = self + .program + .reserve_function_slots(declarations.functions.len()); + for offset in 0..declarations.functions.len() { + let index = start_index + offset; + if index > u16::MAX as usize { + return Err(error::Error::Unsupported( + "Too many user functions".to_string(), + ) + .into()); + } + self.user_function_ids + .push(function::InlineFunctionId::new(index)); + } + self.program + .declarations + .user_functions + .clone_from(&self.user_function_ids); + Ok(()) + } + + fn compile_user_functions( + &mut self, + declarations: &ir::Declarations, + ) -> error::SpannedResult<()> { + if declarations.functions.is_empty() { + return Ok(()); + } + for (index, function_binding) in declarations.functions.iter().enumerate() { + let expected_id = self.user_function_ids[index]; + let mut function_compiler = self.function_compiler(); + function_compiler + .compile_function_id_at(&function_binding.main, expected_id, (0..0).into())?; + } + Ok(()) + } + + fn prepare_named_template_ids( + &mut self, + declarations: &ir::Declarations, + ) -> error::SpannedResult<()> { + self.named_template_ids.clear(); + self.program.declarations.named_templates.clear(); + if declarations.named_templates.is_empty() { + return Ok(()); + } + let start_index = self + .program + .reserve_function_slots(declarations.named_templates.len()); + for (offset, template) in declarations.named_templates.iter().enumerate() { + let index = start_index + offset; + if index > u16::MAX as usize { + return Err(error::Error::Unsupported( + "Too many named templates".to_string(), + ) + .into()); + } + let id = function::InlineFunctionId::new(index); + if self + 
.named_template_ids + .insert(template.name.clone(), id) + .is_some() + { + return Err(error::Error::Unsupported( + "Duplicate named template".to_string(), + ) + .into()); + } + } + self.program + .declarations + .named_templates + .clone_from(&self.named_template_ids); + Ok(()) + } + + fn compile_named_templates( + &mut self, + declarations: &ir::Declarations, + ) -> error::SpannedResult<()> { + if declarations.named_templates.is_empty() { + return Ok(()); + } + for template in &declarations.named_templates { + let expected_id = self + .named_template_ids + .get(&template.name) + .copied() + .ok_or_else(|| { + error::Error::Unsupported(String::from("Named template not registered")) + })?; + let mut function_compiler = self.function_compiler(); + function_compiler.compile_function_id_at( + &template.function_definition, + expected_id, + (0..0).into(), + )?; + let compiled_template_params = Self::build_template_params( + &template.template_params, + &template.function_definition, + &mut function_compiler, + )?; + if !compiled_template_params.is_empty() { + self.program + .declarations + .template_params + .insert(expected_id, compiled_template_params); + } + } + Ok(()) } fn compile_rule(&mut self, rule: &ir::Rule) -> error::SpannedResult<()> { - let mut function_compiler = self.function_compiler(); - let function_id = - function_compiler.compile_function_id(&rule.function_definition, (0..0).into())?; + let (function_id, pattern, compiled_template_params) = { + let mut function_compiler = self.function_compiler(); + let function_id = + function_compiler.compile_function_id(&rule.function_definition, (0..0).into())?; - let pattern = transform_pattern(&rule.pattern, |function_definition| { - function_compiler.compile_function_id(function_definition, (0..0).into()) - })?; + let pattern = transform_pattern(&rule.pattern, |function_definition| { + function_compiler.compile_function_id(function_definition, (0..0).into()) + })?; - self.add_rule(&rule.modes, rule.priority, 
&pattern, function_id); + let compiled_template_params = Self::build_template_params( + &rule.template_params, + &rule.function_definition, + &mut function_compiler, + )?; + (function_id, pattern, compiled_template_params) + }; + if !compiled_template_params.is_empty() { + self.program + .declarations + .template_params + .insert(function_id, compiled_template_params); + } + self.add_rule( + &rule.modes, + rule.priority, + rule.import_level, + rule.is_builtin, + &pattern, + function_id, + ); Ok(()) } @@ -110,6 +314,8 @@ impl<'a> DeclarationCompiler<'a> { &mut self, modes: &[ir::ModeValue], priority: Decimal, + import_level: u32, + is_builtin: bool, pattern: &Pattern, function_id: function::InlineFunctionId, ) { @@ -129,6 +335,8 @@ impl<'a> DeclarationCompiler<'a> { .push(RuleBuilder { priority, declaration_order, + import_level, + is_builtin, pattern: pattern.clone(), function_id, }); @@ -151,14 +359,19 @@ impl<'a> DeclarationCompiler<'a> { } for (mode, mut rule_builders) in self.rule_builders.drain() { - // higher priorities first, same priorities last declaration order wins + // higher priorities first; lower import_level (higher precedence) first; + // same priorities + import_level -> last declaration order wins rule_builders.sort_by_key(|rule_builder| { - (-rule_builder.priority, -rule_builder.declaration_order) + ( + -rule_builder.priority, + rule_builder.import_level as i64, + -rule_builder.declaration_order, + ) }); let rules = rule_builders .drain(..) 
.map(|rule_builder| rule_builder.rule()) - .collect(); + .collect::>(); let apply_templates_mode_value = match mode { ir::ModeValue::Named(name) => ir::ApplyTemplatesModeValue::Named(name), ir::ModeValue::Unnamed => ir::ApplyTemplatesModeValue::Unnamed, @@ -173,8 +386,85 @@ impl<'a> DeclarationCompiler<'a> { .expect("Mode should have been registered"); self.program .declarations - .mode_lookup - .add_rules(mode_id, rules) + .mode_rules + .insert(mode_id, rules); + } + } + + fn compile_global_params(&mut self, declarations: &ir::Declarations) -> error::SpannedResult<()> { + if declarations.global_params.is_empty() { + return Ok(()); + } + let compiled = { + let mut function_compiler = self.function_compiler(); + let params = declarations + .global_params + .iter() + .map(|param| ir::Param { + name: param.var_name.clone(), + type_: None, + }) + .collect::>(); + let mut compiled = Vec::with_capacity(declarations.global_params.len()); + for global_param in &declarations.global_params { + let default_function = if let Some(default_expr) = &global_param.default_expr { + let function_definition = ir::FunctionDefinition { + params: params.clone(), + return_type: None, + body: Box::new(default_expr.clone()), + }; + Some( + function_compiler.compile_function_id(&function_definition, (0..0).into())?, + ) + } else { + None + }; + compiled.push(GlobalParam { + name: global_param.name.clone(), + required: global_param.required, + overrideable: global_param.overrideable, + default: default_function, + }); + } + compiled + }; + self.program + .declarations + .global_params + .extend(compiled); + Ok(()) + } + + fn build_template_params( + template_params: &[ir::TemplateParam], + function_definition: &ir::FunctionDefinition, + function_compiler: &mut FunctionCompiler<'_>, + ) -> error::SpannedResult> { + if template_params.is_empty() { + return Ok(Vec::new()); + } + let mut params = function_definition.params.clone(); + for param in &mut params { + param.type_ = None; + } + let mut 
compiled = Vec::with_capacity(template_params.len()); + for template_param in template_params { + let default = if let Some(default_expr) = &template_param.default_expr { + let function_definition = ir::FunctionDefinition { + params: params.clone(), + return_type: None, + body: Box::new(default_expr.clone()), + }; + Some(function_compiler.compile_function_id(&function_definition, (0..0).into())?) + } else { + None + }; + compiled.push(TemplateParam { + name: template_param.name.clone(), + required: template_param.required, + default, + }); } + Ok(compiled) } } diff --git a/xee-ir/src/function_compiler.rs b/xee-ir/src/function_compiler.rs index 70f4a30d2..2ed61a9c9 100644 --- a/xee-ir/src/function_compiler.rs +++ b/xee-ir/src/function_compiler.rs @@ -1,5 +1,7 @@ +use ahash::HashMap; use ibig::{ibig, IBig}; +use xee_interpreter::atomic; use xee_interpreter::error::Error; use xee_interpreter::function::FunctionRule; use xee_interpreter::interpreter::instruction::Instruction; @@ -17,6 +19,8 @@ pub(crate) type Scopes = scope::Scopes; pub struct FunctionCompiler<'a> { pub(crate) scopes: &'a mut Scopes, pub(crate) mode_ids: &'a ModeIds, + pub(crate) user_function_ids: &'a [function::InlineFunctionId], + pub(crate) named_template_ids: &'a HashMap, pub(crate) builder: FunctionBuilder<'a>, } @@ -25,11 +29,15 @@ impl<'a> FunctionCompiler<'a> { builder: FunctionBuilder<'a>, scopes: &'a mut Scopes, mode_ids: &'a ModeIds, + user_function_ids: &'a [function::InlineFunctionId], + named_template_ids: &'a HashMap, ) -> Self { Self { builder, scopes, mode_ids, + user_function_ids, + named_template_ids, } } @@ -53,6 +61,7 @@ impl<'a> FunctionCompiler<'a> { ir::Expr::Step(step) => self.compile_step(step, span), ir::Expr::Deduplicate(expr) => self.compile_deduplicate(expr, span), ir::Expr::If(if_) => self.compile_if(if_, span), + ir::Expr::TryCatch(try_catch) => self.compile_try_catch(try_catch, span), ir::Expr::Map(map) => self.compile_map(map, span), ir::Expr::Filter(filter) => 
self.compile_filter(filter, span), ir::Expr::Iterate(iterate) => self.compile_iterate(iterate, span), @@ -87,9 +96,17 @@ impl<'a> FunctionCompiler<'a> { self.compile_xml_processing_instruction(processing_instruction, span) } ir::Expr::XmlAppend(xml_append) => self.compile_xml_append(xml_append, span), + ir::Expr::XmlSetType(xml_set_type) => self.compile_xml_set_type(xml_set_type, span), ir::Expr::ApplyTemplates(apply_templates) => { self.compile_apply_templates(apply_templates, span) } + ir::Expr::ApplyImports(apply_imports) => { + self.compile_apply_imports(apply_imports, span) + } + ir::Expr::NextMatch(next_match) => self.compile_next_match(next_match, span), + ir::Expr::CallTemplate(call_template) => { + self.compile_call_template(call_template, span) + } ir::Expr::CopyShallow(copy_shallow) => self.compile_copy_shallow(copy_shallow, span), ir::Expr::CopyDeep(copy_deep) => self.compile_copy_deep(copy_deep, span), } @@ -112,6 +129,15 @@ impl<'a> FunctionCompiler<'a> { ir::Const::Decimal(d) => { self.builder.emit_constant((*d).into(), span); } + ir::Const::UserFunctionReference(index) => { + let function_id = self.user_function_ids.get(*index).ok_or_else(|| { + Error::Unsupported(String::from( + "User function reference out of range", + )) + })?; + self.builder + .emit(Instruction::Closure(function_id.as_u16()), span); + } ir::Const::EmptySequence => self .builder .emit_constant(sequence::Sequence::default(), span), @@ -193,6 +219,31 @@ impl<'a> FunctionCompiler<'a> { Ok(()) } + fn compile_try_catch( + &mut self, + try_catch: &ir::TryCatch, + span: SourceSpan, + ) -> error::SpannedResult<()> { + self.compile_function_definition(&try_catch.try_body, span)?; + for catch in &try_catch.catches { + self.compile_function_definition(&catch.body, span)?; + } + + let entry = xee_interpreter::declaration::TryCatch { + rollback_output: try_catch.rollback_output, + catches: try_catch + .catches + .iter() + .map(|catch| xee_interpreter::declaration::CatchClause { + errors: 
catch.errors.clone(), + }) + .collect(), + }; + let entry_id = self.builder.add_try_catch(entry); + self.builder.emit(Instruction::TryCatch(entry_id), span); + Ok(()) + } + fn compile_binary( &mut self, binary: &ir::Binary, @@ -330,11 +381,11 @@ impl<'a> FunctionCompiler<'a> { Ok(()) } - pub fn compile_function_id( + fn build_inline_function( &mut self, function_definition: &ir::FunctionDefinition, span: SourceSpan, - ) -> error::SpannedResult { + ) -> error::SpannedResult { let nested_builder = self.builder.builder(); self.scopes.push_scope(); @@ -342,6 +393,8 @@ impl<'a> FunctionCompiler<'a> { builder: nested_builder, scopes: self.scopes, mode_ids: self.mode_ids, + user_function_ids: self.user_function_ids, + named_template_ids: self.named_template_ids, }; for param in &function_definition.params { @@ -354,9 +407,17 @@ impl<'a> FunctionCompiler<'a> { compiler.scopes.pop_scope(); - let function = compiler + Ok(compiler .builder - .finish("inline".to_string(), function_definition, span); + .finish("inline".to_string(), function_definition, span)) + } + + pub fn compile_function_id( + &mut self, + function_definition: &ir::FunctionDefinition, + span: SourceSpan, + ) -> error::SpannedResult { + let function = self.build_inline_function(function_definition, span)?; // now place all captured names on stack, to ensure we have the // closure // in reverse order so we can pop them off in the right order @@ -366,6 +427,17 @@ impl<'a> FunctionCompiler<'a> { Ok(self.builder.add_function(function)) } + pub fn compile_function_id_at( + &mut self, + function_definition: &ir::FunctionDefinition, + function_id: function::InlineFunctionId, + span: SourceSpan, + ) -> error::SpannedResult<()> { + let function = self.build_inline_function(function_definition, span)?; + self.builder.set_function(function_id, function); + Ok(()) + } + pub(crate) fn compile_function_definition( &mut self, function_definition: &ir::FunctionDefinition, @@ -976,6 +1048,17 @@ impl<'a> FunctionCompiler<'a> { 
Ok(()) } + fn compile_xml_set_type( + &mut self, + xml_set_type: &ir::XmlSetType, + span: SourceSpan, + ) -> error::SpannedResult<()> { + self.compile_atom(&xml_set_type.node)?; + self.builder + .emit(Instruction::XmlSetType(xml_set_type.xs.to_u16()), span); + Ok(()) + } + fn compile_xml_comment( &mut self, comment: &ir::XmlComment, @@ -1003,24 +1086,172 @@ impl<'a> FunctionCompiler<'a> { apply_templates: &ir::ApplyTemplates, span: SourceSpan, ) -> error::SpannedResult<()> { + if !apply_templates.params.is_empty() { + for param in &apply_templates.params { + self.compile_atom(&param.value)?; + let name_sequence = sequence::Sequence::from(sequence::Item::Atomic( + atomic::Atomic::from(param.name.clone()), + )); + self.builder.emit_constant(name_sequence, span); + } + } self.compile_atom(&apply_templates.select)?; - let mode_id = if matches!( - apply_templates.mode, - ir::ApplyTemplatesModeValue::Named(_) | ir::ApplyTemplatesModeValue::Unnamed - ) { - self.mode_ids.get(&apply_templates.mode) + match apply_templates.mode { + ir::ApplyTemplatesModeValue::Current => { + if apply_templates.params.is_empty() { + self.builder.emit(Instruction::ApplyTemplatesCurrent, span); + } else { + let param_count = apply_templates.params.len(); + if param_count > u16::MAX as usize { + return Err(error::Error::Unsupported( + "Too many xsl:with-param values".to_string(), + ) + .into()); + } + self.builder.emit( + Instruction::ApplyTemplatesCurrentWithParams(param_count as u16), + span, + ); + } + } + ir::ApplyTemplatesModeValue::Named(_) | ir::ApplyTemplatesModeValue::Unnamed => { + let mode_id = self.mode_ids.get(&apply_templates.mode); + if let Some(mode_id) = mode_id { + if apply_templates.params.is_empty() { + self.builder + .emit(Instruction::ApplyTemplates(mode_id.get() as u16), span); + } else { + let param_count = apply_templates.params.len(); + if param_count > u16::MAX as usize { + return Err(error::Error::Unsupported( + "Too many xsl:with-param values".to_string(), + ) +
.into()); + } + self.builder.emit( + Instruction::ApplyTemplatesWithParams( + mode_id.get() as u16, + param_count as u16, + ), + span, + ); + } + } else { + // the mode was never used by any templates, so compile the empty + // sequence + self.builder + .emit_constant(sequence::Sequence::default(), span); + } + } + } + Ok(()) + } + + fn compile_apply_imports( + &mut self, + apply_imports: &ir::ApplyImports, + span: SourceSpan, + ) -> error::SpannedResult<()> { + if !apply_imports.params.is_empty() { + for param in &apply_imports.params { + self.compile_atom(&param.value)?; + let name_sequence = sequence::Sequence::from(sequence::Item::Atomic( + atomic::Atomic::from(param.name.clone()), + )); + self.builder.emit_constant(name_sequence, span); + } + } + if apply_imports.params.is_empty() { + self.builder.emit(Instruction::ApplyImports, span); } else { - todo!("#current mode not handled yet") - }; - if let Some(mode_id) = mode_id { - self.builder - .emit(Instruction::ApplyTemplates(mode_id.get() as u16), span); + let param_count = apply_imports.params.len(); + if param_count > u16::MAX as usize { + return Err(error::Error::Unsupported( + "Too many xsl:with-param values".to_string(), + ) + .into()); + } + self.builder.emit( + Instruction::ApplyImportsWithParams(param_count as u16), + span, + ); + } + Ok(()) + } + + fn compile_next_match( + &mut self, + next_match: &ir::NextMatch, + span: SourceSpan, + ) -> error::SpannedResult<()> { + if !next_match.params.is_empty() { + for param in &next_match.params { + self.compile_atom(&param.value)?; + let name_sequence = sequence::Sequence::from(sequence::Item::Atomic( + atomic::Atomic::from(param.name.clone()), + )); + self.builder.emit_constant(name_sequence, span); + } + } + if next_match.params.is_empty() { + self.builder.emit(Instruction::ApplyNextMatch, span); } else { - // the mode was never used by any templates, so compile the empty - // sequence + let param_count = next_match.params.len(); + if param_count > u16::MAX as usize
{ + return Err(error::Error::Unsupported( + "Too many xsl:with-param values".to_string(), + ) + .into()); + } + self.builder.emit( + Instruction::ApplyNextMatchWithParams(param_count as u16), + span, + ); + } + Ok(()) + } + + fn compile_call_template( + &mut self, + call_template: &ir::CallTemplate, + span: SourceSpan, + ) -> error::SpannedResult<()> { + let function_id = self + .named_template_ids + .get(&call_template.name) + .ok_or_else(|| { + error::Error::Unsupported(String::from("Named template not found")) + })?; + + if !call_template.params.is_empty() { + for param in &call_template.params { + self.compile_atom(&param.value)?; + let name_sequence = sequence::Sequence::from(sequence::Item::Atomic( + atomic::Atomic::from(param.name.clone()), + )); + self.builder.emit_constant(name_sequence, span); + } + } + + if call_template.params.is_empty() { self.builder - .emit_constant(sequence::Sequence::default(), span); + .emit(Instruction::CallTemplate(function_id.as_u16()), span); + } else { + let param_count = call_template.params.len(); + if param_count > u16::MAX as usize { + return Err(error::Error::Unsupported( + "Too many xsl:with-param values".to_string(), + ) + .into()); + } + self.builder.emit( + Instruction::CallTemplateWithParams( + function_id.as_u16(), + param_count as u16, + ), + span, + ); } Ok(()) } diff --git a/xee-ir/src/ir.rs b/xee-ir/src/ir.rs index d00a2e2ee..6cc8db15d 100644 --- a/xee-ir/src/ir.rs +++ b/xee-ir/src/ir.rs @@ -9,6 +9,7 @@ use rust_decimal::Decimal; pub use xee_interpreter::function::Name; use xee_interpreter::function::{CastType, Signature, StaticFunctionId}; +use xee_interpreter::declaration::CatchError; use xee_interpreter::sequence::SerializationParameters; use xee_interpreter::xml; use xee_schema_type::Xs; @@ -25,6 +26,7 @@ pub enum Expr { Atom(AtomS), Let(Let), If(If), + TryCatch(TryCatch), Binary(Binary), Unary(Unary), FunctionDefinition(FunctionDefinition), @@ -55,7 +57,11 @@ pub enum Expr { XmlComment(XmlComment),
XmlProcessingInstruction(XmlProcessingInstruction), XmlAppend(XmlAppend), + XmlSetType(XmlSetType), ApplyTemplates(ApplyTemplates), + ApplyImports(ApplyImports), + NextMatch(NextMatch), + CallTemplate(CallTemplate), CopyShallow(CopyShallow), CopyDeep(CopyDeep), } @@ -74,6 +80,7 @@ pub enum Const { Double(OrderedFloat), Decimal(Decimal), StaticFunctionReference(StaticFunctionId, Option), + UserFunctionReference(usize), // XXX replace this with a sequence constant? useful once we have constant folding EmptySequence, } @@ -99,6 +106,19 @@ pub struct If { pub else_: Box, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TryCatch { + pub try_body: FunctionDefinition, + pub catches: Vec, + pub rollback_output: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CatchClause { + pub errors: Vec, + pub body: FunctionDefinition, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Binary { pub left: AtomS, @@ -329,10 +349,39 @@ pub struct XmlAppend { pub child: AtomS, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct XmlSetType { + pub node: AtomS, + pub xs: Xs, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct WithParam { + pub name: xmlname::OwnedName, + pub value: AtomS, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ApplyTemplates { pub mode: ApplyTemplatesModeValue, pub select: AtomS, + pub params: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ApplyImports { + pub params: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NextMatch { + pub params: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CallTemplate { + pub name: xmlname::OwnedName, + pub params: Vec, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -356,8 +405,11 @@ pub struct CopyDeep { pub struct Rule { pub modes: Vec, pub priority: Decimal, + pub import_level: u32, + pub is_builtin: bool, pub pattern: Pattern, pub function_definition: FunctionDefinition, + pub template_params: Vec, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ 
-368,13 +420,27 @@ pub enum ModeValue { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Mode {} +pub struct Mode { + pub on_no_match: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum OnNoMatch { + DeepCopy, + ShallowCopy, + DeepSkip, + ShallowSkip, + TextOnlyCopy, + Fail, +} #[derive(Debug, Clone, PartialEq, Eq)] pub struct Declarations { pub rules: Vec, pub modes: HashMap, Mode>, pub functions: Vec, + pub named_templates: Vec, + pub global_params: Vec, pub main: FunctionDefinition, pub serialization_params: SerializationParameters, } @@ -385,14 +451,42 @@ impl Declarations { rules: Vec::new(), modes: HashMap::new(), functions: Vec::new(), + named_templates: Vec::new(), + global_params: Vec::new(), main, serialization_params: SerializationParameters::new(), } } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlobalParam { + pub name: xmlname::OwnedName, + pub var_name: Name, + pub required: bool, + pub overrideable: bool, + pub default_expr: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TemplateParam { + pub name: xmlname::OwnedName, + pub var_name: Name, + pub required: bool, + pub default_expr: Option, + pub type_: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NamedTemplate { + pub name: xmlname::OwnedName, + pub function_definition: FunctionDefinition, + pub template_params: Vec, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct FunctionBinding { - pub name: Name, + pub name: xmlname::OwnedName, + pub arity: u8, pub main: FunctionDefinition, } diff --git a/xee-ir/tests/test_xml_ir.rs b/xee-ir/tests/test_xml_ir.rs index 1db04dd97..6cd8337ea 100644 --- a/xee-ir/tests/test_xml_ir.rs +++ b/xee-ir/tests/test_xml_ir.rs @@ -1,7 +1,10 @@ -use ahash::HashMapExt; +use ahash::{HashMap, HashMapExt}; use insta::assert_debug_snapshot; -use xee_interpreter::interpreter::{instruction::decode_instructions, Program}; +use xee_interpreter::{ + function, + interpreter::{instruction::decode_instructions, Program}, 
+}; use xee_ir::{ir, FunctionBuilder, FunctionCompiler, ModeIds, Scopes}; use xee_xpath_ast::span::Spanned; @@ -85,7 +88,16 @@ fn test_generate_element() { let function_builder = FunctionBuilder::new(&mut program); let mut scopes = Scopes::new(); let empty_mode_ids = ModeIds::new(); - let mut compiler = FunctionCompiler::new(function_builder, &mut scopes, &empty_mode_ids); + let empty_user_functions: Vec = Vec::new(); + let empty_named_templates: HashMap = + HashMap::new(); + let mut compiler = FunctionCompiler::new( + function_builder, + &mut scopes, + &empty_mode_ids, + &empty_user_functions, + &empty_named_templates, + ); compiler.compile_expr(&outer_expr).unwrap(); diff --git a/xee-schema-type/src/xs.rs b/xee-schema-type/src/xs.rs index 1238713ec..cb9be13f3 100644 --- a/xee-schema-type/src/xs.rs +++ b/xee-schema-type/src/xs.rs @@ -2,6 +2,7 @@ const XS_NAMESPACE: &str = "http://www.w3.org/2001/XMLSchema"; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] +#[repr(u16)] pub enum Xs { AnyType, AnySimpleType, @@ -87,6 +88,68 @@ impl RustInfo { } impl Xs { + pub const fn to_u16(self) -> u16 { + self as u16 + } + + pub const fn from_u16(value: u16) -> Option { + use Xs::*; + match value { + x if x == AnyType as u16 => Some(AnyType), + x if x == AnySimpleType as u16 => Some(AnySimpleType), + x if x == Error as u16 => Some(Error), + x if x == Untyped as u16 => Some(Untyped), + x if x == AnyAtomicType as u16 => Some(AnyAtomicType), + x if x == Numeric as u16 => Some(Numeric), + x if x == String as u16 => Some(String), + x if x == UntypedAtomic as u16 => Some(UntypedAtomic), + x if x == Boolean as u16 => Some(Boolean), + x if x == Decimal as u16 => Some(Decimal), + x if x == NonPositiveInteger as u16 => Some(NonPositiveInteger), + x if x == NegativeInteger as u16 => Some(NegativeInteger), + x if x == NonNegativeInteger as u16 => Some(NonNegativeInteger), + x if x == PositiveInteger as u16 => Some(PositiveInteger), 
+ x if x == Integer as u16 => Some(Integer), + x if x == Long as u16 => Some(Long), + x if x == Int as u16 => Some(Int), + x if x == Short as u16 => Some(Short), + x if x == Byte as u16 => Some(Byte), + x if x == UnsignedLong as u16 => Some(UnsignedLong), + x if x == UnsignedInt as u16 => Some(UnsignedInt), + x if x == UnsignedShort as u16 => Some(UnsignedShort), + x if x == UnsignedByte as u16 => Some(UnsignedByte), + x if x == Float as u16 => Some(Float), + x if x == Double as u16 => Some(Double), + x if x == QName as u16 => Some(QName), + x if x == Notation as u16 => Some(Notation), + x if x == Duration as u16 => Some(Duration), + x if x == YearMonthDuration as u16 => Some(YearMonthDuration), + x if x == DayTimeDuration as u16 => Some(DayTimeDuration), + x if x == Time as u16 => Some(Time), + x if x == GYearMonth as u16 => Some(GYearMonth), + x if x == GYear as u16 => Some(GYear), + x if x == GMonthDay as u16 => Some(GMonthDay), + x if x == GMonth as u16 => Some(GMonth), + x if x == GDay as u16 => Some(GDay), + x if x == Base64Binary as u16 => Some(Base64Binary), + x if x == HexBinary as u16 => Some(HexBinary), + x if x == AnyURI as u16 => Some(AnyURI), + x if x == DateTime as u16 => Some(DateTime), + x if x == DateTimeStamp as u16 => Some(DateTimeStamp), + x if x == Date as u16 => Some(Date), + x if x == NormalizedString as u16 => Some(NormalizedString), + x if x == Token as u16 => Some(Token), + x if x == Language as u16 => Some(Language), + x if x == NMTOKEN as u16 => Some(NMTOKEN), + x if x == Name as u16 => Some(Name), + x if x == NCName as u16 => Some(NCName), + x if x == ID as u16 => Some(ID), + x if x == IDREF as u16 => Some(IDREF), + x if x == ENTITY as u16 => Some(ENTITY), + _ => None, + } + } + pub fn by_name(namespace: &str, local_name: &str) -> Option { if namespace == XS_NAMESPACE { Xs::by_local_name(local_name) diff --git a/xee-testrunner/README.md b/xee-testrunner/README.md index 0e0179e52..28c4822d9 100644 --- a/xee-testrunner/README.md +++ 
b/xee-testrunner/README.md @@ -4,8 +4,8 @@ This is a test runner that can run the XPath conformance test suite in the [Xee -project](https://github.com/Paligo/xee). Work on enabling the XSLT conformance -test suite is in progress. +project](https://github.com/Paligo/xee). It can also run the XSLT conformance +test suite, though coverage is still partial. We have added both the [XPath conformance test suite](https://github.com/w3c/qt3tests) and the [XSLT conformance test @@ -27,7 +27,7 @@ To check against regressions, run: cargo run --release -- check ../vendor/xpath-tests/ ``` -or (in the future) +or ``` cargo run --release -- check ../vendor/xslt-tests/ @@ -40,6 +40,10 @@ cargo run --release -- all ../vendor/xpath-tests/ cargo run --release -- all ../vendor/xslt-tests/ ``` +For XSLT, the runner reads `` and `initial-template` from the test +suite metadata. The `enable_assertions` dependency is honored to toggle +`xsl:assert` behavior. + You can run the tests and update the regression filter accordingly: ``` @@ -63,4 +67,4 @@ commandline tool, download a ## Credits This project was made possible by the generous support of -[Paligo](https://paligo.net/). \ No newline at end of file +[Paligo](https://paligo.net/). 
diff --git a/xee-testrunner/src/dependency.rs b/xee-testrunner/src/dependency.rs index dfd33736c..1a6076685 100644 --- a/xee-testrunner/src/dependency.rs +++ b/xee-testrunner/src/dependency.rs @@ -51,11 +51,19 @@ impl KnownDependencies { impl Dependency { pub(crate) fn load<'a>(queries: &'a Queries) -> Result>> + 'a> { let satisfied_query = queries.option("@satisfied/string()", convert_string)?; + let satisfied_query_dep = satisfied_query.clone(); + let satisfied_query_enable = satisfied_query.clone(); let type_query = queries.one("@type/string()", convert_string)?; let value_query = queries.one("@value/string()", convert_string)?; + let satisfied_query_dep_direct = satisfied_query_dep.clone(); + let satisfied_query_dep_nested = satisfied_query_dep.clone(); + let type_query_direct = type_query.clone(); + let type_query_nested = type_query.clone(); + let value_query_direct = value_query.clone(); + let value_query_nested = value_query.clone(); - let dependency_query = queries.many("dependency", move |session, item| { - let satisfied = satisfied_query.execute(session, item)?; + let dependency_query_direct = queries.many("dependency", move |session, item| { + let satisfied = satisfied_query_dep_direct.execute(session, item)?; let satisfied = if let Some(satisfied) = satisfied { if satisfied == "true" { true @@ -67,9 +75,9 @@ impl Dependency { } else { true }; - let value = value_query.execute(session, item)?; + let value = value_query_direct.execute(session, item)?; let values = value.split(' '); - let type_ = type_query.execute(session, item)?; + let type_ = type_query_direct.execute(session, item)?; Ok(values .map(|value| Dependency { spec: DependencySpec { @@ -80,7 +88,88 @@ impl Dependency { }) .collect::>()) })?; - Ok(dependency_query) + let dependency_query_nested = queries.many("dependencies/dependency", move |session, item| { + let satisfied = satisfied_query_dep_nested.execute(session, item)?; + let satisfied = if let Some(satisfied) = satisfied { + if 
satisfied == "true" { + true + } else if satisfied == "false" { + false + } else { + panic!("Unexpected satisfied value: {:?}", satisfied); + } + } else { + true + }; + let value = value_query_nested.execute(session, item)?; + let values = value.split(' '); + let type_ = type_query_nested.execute(session, item)?; + Ok(values + .map(|value| Dependency { + spec: DependencySpec { + type_: type_.clone(), + value: value.to_string(), + }, + satisfied, + }) + .collect::>()) + })?; + let satisfied_query_enable_direct = satisfied_query_enable.clone(); + let satisfied_query_enable_nested = satisfied_query_enable.clone(); + let enable_assertions_query_direct = + queries.many("enable_assertions", move |session, item| { + let satisfied = satisfied_query_enable_direct.execute(session, item)?; + let satisfied = if let Some(satisfied) = satisfied { + if satisfied == "true" { + true + } else if satisfied == "false" { + false + } else { + panic!("Unexpected satisfied value: {:?}", satisfied); + } + } else { + true + }; + Ok(vec![Dependency { + spec: DependencySpec { + type_: "feature".to_string(), + value: "enable_assertions".to_string(), + }, + satisfied, + }]) + })?; + let enable_assertions_query_nested = + queries.many("dependencies/enable_assertions", move |session, item| { + let satisfied = satisfied_query_enable_nested.execute(session, item)?; + let satisfied = if let Some(satisfied) = satisfied { + if satisfied == "true" { + true + } else if satisfied == "false" { + false + } else { + panic!("Unexpected satisfied value: {:?}", satisfied); + } + } else { + true + }; + Ok(vec![Dependency { + spec: DependencySpec { + type_: "feature".to_string(), + value: "enable_assertions".to_string(), + }, + satisfied, + }]) + })?; + Ok(queries.one(".", move |session, item| { + let mut deps = dependency_query_direct.execute(session, item)?; + let mut nested = dependency_query_nested.execute(session, item)?; + deps.append(&mut nested); + let mut enable = 
enable_assertions_query_direct.execute(session, item)?; + let mut enable_nested = enable_assertions_query_nested.execute(session, item)?; + enable.append(&mut enable_nested); + deps.append(&mut enable); + Ok(deps) + })?) } } @@ -110,13 +199,24 @@ impl Dependencies { pub(crate) fn is_feature_supported(&self, known_dependencies: &KnownDependencies) -> bool { for dependency in &self.dependencies { // if a listed feature dependency is not supported, we don't support this - if dependency.spec.type_ == "feature" && !known_dependencies.is_supported(dependency) { + if dependency.spec.type_ == "feature" + && dependency.spec.value != "enable_assertions" + && !known_dependencies.is_supported(dependency) + { return false; } } true } + pub(crate) fn is_feature_disabled(&self, value: &str) -> bool { + self.dependencies.iter().any(|dependency| { + dependency.spec.type_ == "feature" + && dependency.spec.value == value + && !dependency.satisfied + }) + } + // the XML version is supported if the xml-version is the same pub(crate) fn is_xml_version_supported(&self, known_dependencies: &KnownDependencies) -> bool { for dependency in &self.dependencies { diff --git a/xee-testrunner/src/environment/xslt.rs b/xee-testrunner/src/environment/xslt.rs index abd1d05b0..306210726 100644 --- a/xee-testrunner/src/environment/xslt.rs +++ b/xee-testrunner/src/environment/xslt.rs @@ -1,7 +1,7 @@ use anyhow::Result; use xee_xpath::{Queries, Query}; -use xee_xpath_load::ContextLoadable; +use xee_xpath_load::{convert_string, ContextLoadable}; use crate::catalog::LoadContext; @@ -14,7 +14,7 @@ pub(crate) struct Package { #[derive(Debug, Clone)] pub(crate) struct Stylesheet { - // TODO + pub(crate) path: Option, } #[derive(Debug, Clone)] @@ -53,12 +53,16 @@ impl Environment for XsltEnvironmentSpec { fn load(queries: &Queries, context: &LoadContext) -> Result> { let environment_spec_query = EnvironmentSpec::load_with_context(queries, context)?; + let file_query = queries.option("@file/string()", 
convert_string)?; + let stylesheets_query = queries.many("stylesheet", move |documents, item| { + let file = file_query.execute(documents, item)?; + Ok(Stylesheet { path: file }) + })?; let xslt_environment_spec_query = queries.one(".", move |session, item| { Ok(XsltEnvironmentSpec { environment_spec: environment_spec_query.execute(session, item)?, - // TODO packages: vec![], - stylesheets: vec![], + stylesheets: stylesheets_query.execute(session, item)?, outputs: vec![], }) })?; diff --git a/xee-testrunner/src/testcase/assert.rs b/xee-testrunner/src/testcase/assert.rs index 2b6999da7..0c8f9a2f3 100644 --- a/xee-testrunner/src/testcase/assert.rs +++ b/xee-testrunner/src/testcase/assert.rs @@ -496,7 +496,9 @@ impl Assertable for AssertType { documents: &mut Documents, sequence: &Sequence, ) -> TestOutcome { - let matches = sequence.matches_type(&self.0, documents.xot(), &|function| { + let type_table = documents.type_table(); + let type_table = type_table.borrow(); + let matches = sequence.matches_type(&self.0, documents.xot(), &type_table, &|function| { context.function_info(function).signature() }); match matches { @@ -633,7 +635,12 @@ pub struct AssertError(String); impl AssertError { pub(crate) fn new(code: String) -> Self { - Self(code) + let local = if let Some(rest) = code.strip_prefix("Q{") { + rest.split_once('}').map(|(_, local)| local).unwrap_or(&code) + } else { + code.rsplit_once(':').map(|(_, local)| local).unwrap_or(&code) + }; + Self(local.to_string()) } pub(crate) fn assert_error(&self, error: &error::ErrorValue) -> TestOutcome { @@ -1071,9 +1078,33 @@ fn run_xpath_with_result( let q = queries.sequence_with_context(expr, static_context)?; let variables = AHashMap::from([(name, sequence.clone())]); + let context_item = sequence.iter().next(); + let context_item = match context_item { + Some(Item::Node(node)) => { + let xot = documents.xot_mut(); + let root = xot.root(node); + let item_node = match xot.value(root) { + xot::Value::Document => root, + _ 
=> { + let doc = xot.new_document(); + let cloned = xot.clone_node(node); + xot.append(doc, cloned).unwrap(); + doc + } + }; + Some(Item::Node(item_node)) + } + Some(item) => Some(item), + None => None, + }; + let type_table = documents.type_table().clone(); q.execute_build_context(documents, |build| { build.variables(variables); + build.type_table(type_table.clone()); + if let Some(item) = context_item { + build.context_item(item); + } }) } diff --git a/xee-testrunner/src/testcase/xslt.rs b/xee-testrunner/src/testcase/xslt.rs index e2adf3ef9..c05c6f3d5 100644 --- a/xee-testrunner/src/testcase/xslt.rs +++ b/xee-testrunner/src/testcase/xslt.rs @@ -2,10 +2,11 @@ use std::path::PathBuf; use anyhow::Result; use iri_string::types::IriAbsoluteString; +use xot::xmlname::OwnedName as Name; use xee_xpath::{ context::{self, StaticContextBuilder}, - Queries, Query, + Documents, Queries, Query, }; use xee_xpath_load::{convert_string, ContextLoadable}; @@ -17,6 +18,7 @@ use crate::{ }; use super::{ + assert::TestCaseResult, core::{Runnable, TestCase}, outcome::TestOutcome, }; @@ -33,13 +35,21 @@ impl XsltTestCase {} pub(crate) struct XsltTest { pub(crate) base_dir: PathBuf, pub(crate) stylesheets: Vec, + pub(crate) params: Vec, + pub(crate) initial_template: Option, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct Stylesheet { pub(crate) path: Option, } +#[derive(Debug, Clone)] +pub(crate) struct TestParam { + pub(crate) name: Name, + pub(crate) select: String, +} + impl Runnable for XsltTestCase { fn test_case(&self) -> &TestCase { &self.test_case @@ -51,11 +61,33 @@ impl Runnable for XsltTestCase { catalog: &Catalog, test_set: &TestSet, ) -> TestOutcome { - // TODO take the first stylesheet for now - if self.test.stylesheets.is_empty() { - return TestOutcome::EnvironmentError("No stylesheet found".to_string()); - } - let stylesheet = &self.test.stylesheets[0]; + let stylesheet = if self.test.stylesheets.is_empty() { + let environments = match self + .test_case + 
.environments(catalog, test_set) + .collect::, crate::error::Error>>() + { + Ok(environments) => environments, + Err(error) => { + return TestOutcome::EnvironmentError(format!( + "Error loading environments: {}", + error + )) + } + }; + let environment_stylesheet = environments + .iter() + .find_map(|environment| environment.stylesheets.first()) + .and_then(|stylesheet| stylesheet.path.clone()); + match environment_stylesheet { + Some(path) => Stylesheet { path: Some(path) }, + None => { + return TestOutcome::EnvironmentError("No stylesheet found".to_string()); + } + } + } else { + self.test.stylesheets[0].clone() + }; // construct full path let path = self.test.base_dir.join(stylesheet.path.as_ref().unwrap()); // load xml text from file @@ -70,21 +102,6 @@ impl Runnable for XsltTestCase { )) } }; - let static_context_builder = StaticContextBuilder::default(); - let static_context = static_context_builder.build(); - let program = xee_xslt_compiler::parse(static_context, &xslt); - let program = match program { - Ok(program) => program, - Err(error) => { - return TestOutcome::EnvironmentError(format!( - "Error parsing stylesheet: {}", - error - )) - } - }; - - // let root = run_context.documents.xot().parse(xml).unwrap(); - // get static base URI: todo refactor out into its own function let static_base_uri = self.test_case.static_base_uri(catalog, test_set); let static_base_uri = match static_base_uri { @@ -105,6 +122,66 @@ impl Runnable for XsltTestCase { Some(test_set.file_uri()) }; + let mut static_context_builder = StaticContextBuilder::default(); + let assertions_enabled = !self + .test_case + .dependencies + .is_feature_disabled("enable_assertions") + && !test_set + .dependencies + .is_feature_disabled("enable_assertions"); + static_context_builder.assertions_enabled(assertions_enabled); + let variables = + self.test_case + .variables(run_context, catalog, test_set, static_base_uri.as_deref()); + let mut variables = match variables { + Ok(variables) => 
variables, + Err(error) => return TestOutcome::EnvironmentError(error.to_string()), + }; + for param in &self.test.params { + let queries = Queries::default(); + let query = match queries.sequence(&param.select) { + Ok(query) => query, + Err(error) => { + return TestOutcome::EnvironmentError(format!( + "param: select xpath parse failed: {}", + error + )) + } + }; + let mut documents = Documents::new(); + let dynamic_context_builder = query.dynamic_context_builder(&documents); + let dynamic_context = dynamic_context_builder.build(); + let result = match query.execute_with_context(&mut documents, &dynamic_context) { + Ok(result) => result, + Err(error) => { + return TestOutcome::EnvironmentError(format!( + "param: select xpath eval failed: {}", + error + )) + } + }; + variables.insert(param.name.clone(), result); + } + let variable_names: Vec<_> = variables.keys().cloned().collect(); + static_context_builder.variable_names(variable_names); + let static_context = static_context_builder.build(); + let program = xee_xslt_compiler::parse_with_base(static_context, &xslt, Some(&path)); + let program = match program { + Ok(program) => program, + Err(error) => { + return match &self.test_case.result { + TestCaseResult::AssertError(assert_error) => { + assert_error.assert_error(&error.error) + } + TestCaseResult::AnyOf(any_of) => any_of.assert_error(&error.error), + _ => TestOutcome::CompilationError(error.error), + } + } + }; + + // let root = run_context.documents.xot().parse(xml).unwrap(); + // load all the sources // this makes the sources available on the appropriate URLs let r = @@ -131,11 +208,20 @@ impl Runnable for XsltTestCase { builder.context_item(context_item); } builder.documents(run_context.documents.documents().clone()); - // builder.variables(variables.clone()); + builder.type_table(run_context.documents.type_table().clone()); + builder.variables(variables.clone()); builder.current_datetime(chrono::offset::Utc::now().into()); let context = builder.build(); let 
runnable = program.runnable(&context); - let result = runnable.many(run_context.documents.xot_mut()); + let result = if let Some(initial_template) = &self.test.initial_template { + runnable.call_named_template( + run_context.documents.xot_mut(), + initial_template, + None, + ) + } else { + runnable.many(run_context.documents.xot_mut()) + }; self.test_case.result.assert_result( &context, @@ -162,16 +248,33 @@ impl ContextLoadable for XsltTestCase { let file = file_query.execute(documents, item)?; Ok(Stylesheet { path: file }) })?; + let param_name_query = queries.one("@name/string()", convert_string)?; + let param_select_query = queries.one("@select/string()", convert_string)?; + let params_query = queries.many("param", move |documents, item| { + let name = param_name_query.execute(documents, item)?; + let select = param_select_query.execute(documents, item)?; + Ok(TestParam { + name: Name::name(&name), + select, + }) + })?; + let initial_template_query = + queries.option("initial-template/@name/string()", convert_string)?; let xslt_test_query = queries.one("test", move |documents, item| { // the base dir is the same as the test set path, but // without the filename let base_dir = context.path.parent().unwrap(); let stylesheets = stylesheets_query.execute(documents, item)?; + let params = params_query.execute(documents, item)?; + let initial_template = initial_template_query.execute(documents, item)?; + let initial_template = initial_template.map(|name| Name::name(&name)); Ok(XsltTest { stylesheets, + params, base_dir: base_dir.to_path_buf(), + initial_template, }) })?; let test_case_query = TestCase::load_with_context(queries, context)?; diff --git a/xee-xpath-ast/src/parser/axis_node_test.rs b/xee-xpath-ast/src/parser/axis_node_test.rs index 4572ab4fc..92e5748a0 100644 --- a/xee-xpath-ast/src/parser/axis_node_test.rs +++ b/xee-xpath-ast/src/parser/axis_node_test.rs @@ -13,8 +13,6 @@ pub(crate) struct ParserAxisNodeTestOutput<'a, I> where I: ValueInput<'a, Token = 
Token<'a>, Span = Span>, { - pub(crate) node_test: BoxedParser<'a, I, ast::NodeTest>, - pub(crate) abbrev_forward_step: BoxedParser<'a, I, (ast::Axis, ast::NodeTest)>, pub(crate) axis_node_test: BoxedParser<'a, I, (ast::Axis, ast::NodeTest)>, } @@ -214,10 +212,7 @@ where let axis_node_test = reverse_step.or(forward_step).boxed(); - let node_test = node_test_element_name.or(node_test_attribute_name).boxed(); ParserAxisNodeTestOutput { - node_test, - abbrev_forward_step, axis_node_test, } } diff --git a/xee-xpath-ast/src/parser/mod.rs b/xee-xpath-ast/src/parser/mod.rs index b65690732..8448fa0da 100644 --- a/xee-xpath-ast/src/parser/mod.rs +++ b/xee-xpath-ast/src/parser/mod.rs @@ -57,7 +57,11 @@ impl ast::XPath { namespaces: &'a Namespaces, variable_names: &'a VariableNames, ) -> Result { - let mut xpath = parse(parser().xpath, tokens(input), Cow::Borrowed(namespaces))?; + let mut xpath = parse( + parser().xpath, + tokens(input), + Cow::Borrowed(namespaces), + )?; // rename all variables to unique names unique_names(&mut xpath, variable_names); Ok(xpath) @@ -90,41 +94,67 @@ impl ast::XPath { impl ast::ExprSingle { pub fn parse(src: &str) -> Result { let namespaces = Namespaces::default(); - parse(parser().expr_single, tokens(src), Cow::Owned(namespaces)) + let parsed = parse( + parser().expr_single, + tokens(src), + Cow::Owned(namespaces), + )?; + Ok(parsed) } } impl ast::Signature { pub fn parse<'a>(input: &'a str, namespaces: &'a Namespaces) -> Result { - parse(parser().signature, tokens(input), Cow::Borrowed(namespaces)) + let parsed = parse( + parser().signature, + tokens(input), + Cow::Borrowed(namespaces), + )?; + Ok(parsed) } } pub fn parse_kind_test(src: &str) -> Result { let namespaces = Namespaces::default(); - parse(parser().kind_test, tokens(src), Cow::Owned(namespaces)) + let parsed = parse( + parser().kind_test, + tokens(src), + Cow::Owned(namespaces), + )?; + Ok(parsed) } pub fn parse_sequence_type<'a>( input: &'a str, namespaces: &'a Namespaces, ) -> 
Result { - parse( + let parsed = parse( parser().sequence_type, tokens(input), Cow::Borrowed(namespaces), - ) + )?; + Ok(parsed) } pub fn parse_item_type<'a>( input: &'a str, namespaces: &'a Namespaces, ) -> Result { - parse(parser().item_type, tokens(input), Cow::Borrowed(namespaces)) + let parsed = parse( + parser().item_type, + tokens(input), + Cow::Borrowed(namespaces), + )?; + Ok(parsed) } pub fn parse_name<'a>(src: &'a str, namespaces: &'a Namespaces) -> Result { - parse(parser().name, tokens(src), Cow::Borrowed(namespaces)) + let parsed = parse( + parser().name, + tokens(src), + Cow::Borrowed(namespaces), + )?; + Ok(parsed) } #[cfg(test)] @@ -132,12 +162,16 @@ mod tests { use super::*; - use ahash::HashSetExt; use insta::assert_ron_snapshot; fn parse_xpath_simple(src: &str) -> Result { let namespaces = Namespaces::default(); - parse(parser().xpath, tokens(src), Cow::Owned(namespaces)) + let parsed = parse( + parser().xpath, + tokens(src), + Cow::Owned(namespaces), + )?; + Ok(parsed) } fn parse_xpath_simple_element_ns(src: &str) -> Result { @@ -146,7 +180,12 @@ mod tests { "http://example.com".to_string(), "".to_string(), ); - parse(parser().xpath, tokens(src), Cow::Owned(namespaces)) + let parsed = parse( + parser().xpath, + tokens(src), + Cow::Owned(namespaces), + )?; + Ok(parsed) } #[test] @@ -862,7 +901,12 @@ mod tests { fn test_xpath_parse_value_template() { let namespaces = Namespaces::default(); let xpath = - ast::XPath::parse_value_template("1 + 2}", &namespaces, &VariableNames::new()).unwrap(); + ast::XPath::parse_value_template( + "1 + 2}", + &namespaces, + &VariableNames::default(), + ) + .unwrap(); assert_eq!(xpath.0.span, Span::new(0, 5)); assert_ron_snapshot!(xpath); } @@ -871,7 +915,11 @@ mod tests { fn test_xpath_parse_value_template_with_leftover() { let namespaces = Namespaces::default(); let xpath = - ast::XPath::parse_value_template("1 + 2}foo", &namespaces, &VariableNames::new()) + ast::XPath::parse_value_template( + "1 + 2}foo", + 
&namespaces, + &VariableNames::default(), + ) .unwrap(); assert_eq!(xpath.0.span, Span::new(0, 5)); assert_ron_snapshot!(xpath); @@ -881,7 +929,12 @@ mod tests { fn test_xpath_parse_value_template_a_with_leftover() { let namespaces = Namespaces::default(); let xpath = - ast::XPath::parse_value_template("a}foo", &namespaces, &VariableNames::new()).unwrap(); + ast::XPath::parse_value_template( + "a}foo", + &namespaces, + &VariableNames::default(), + ) + .unwrap(); assert_eq!(xpath.0.span, Span::new(0, 1)); assert_ron_snapshot!(xpath); } @@ -890,7 +943,11 @@ mod tests { fn test_xpath_parse_value_template_with_second_value_following() { let namespaces = Namespaces::default(); let xpath = - ast::XPath::parse_value_template("a}foo{b}!", &namespaces, &VariableNames::new()) + ast::XPath::parse_value_template( + "a}foo{b}!", + &namespaces, + &VariableNames::default(), + ) .unwrap(); assert_eq!(xpath.0.span, Span::new(0, 1)); assert_ron_snapshot!(xpath); diff --git a/xee-xpath-ast/src/parser/parser_core.rs b/xee-xpath-ast/src/parser/parser_core.rs index 329e35048..69ddcf15e 100644 --- a/xee-xpath-ast/src/parser/parser_core.rs +++ b/xee-xpath-ast/src/parser/parser_core.rs @@ -23,7 +23,6 @@ where { pub(crate) name: BoxedParser<'a, I, ast::NameS>, pub(crate) expr_single: BoxedParser<'a, I, ast::ExprSingleS>, - pub(crate) expr_single_core: BoxedParser<'a, I, ast::ExprSingleS>, pub(crate) signature: BoxedParser<'a, I, ast::Signature>, pub(crate) item_type: BoxedParser<'a, I, ast::ItemType>, pub(crate) sequence_type: BoxedParser<'a, I, ast::SequenceType>, @@ -826,7 +825,6 @@ where .boxed(); let name = eqname.clone().then_ignore(end()).boxed(); - let expr_single_core = expr_single.clone(); let expr_single = expr_single.then_ignore(end()).boxed(); let xpath = expr_ .clone() @@ -854,7 +852,6 @@ where ParserOutput { name, expr_single, - expr_single_core, xpath, xpath_right_brace, signature, diff --git a/xee-xpath-ast/src/parser/pattern.rs b/xee-xpath-ast/src/parser/pattern.rs index 
bcd3b368c..b4a06f9d7 100644 --- a/xee-xpath-ast/src/parser/pattern.rs +++ b/xee-xpath-ast/src/parser/pattern.rs @@ -1,345 +1,507 @@ -use chumsky::{input::ValueInput, prelude::*}; -use std::borrow::Cow; use xot::xmlname::NameStrInfo; -use xee_xpath_lexer::Token; - use crate::ast::Span; -use crate::{ast, WithSpan, FN_NAMESPACE}; +use crate::{ast, FN_NAMESPACE}; use crate::{pattern, Namespaces, ParserError, VariableNames}; -use super::axis_node_test::parser_axis_node_test; -use super::name::parser_name; -use super::parser_core::parser as xpath_parser; -use super::primary::parser_primary; -use super::{parse, tokens}; - -use super::types::BoxedParser; - -#[derive(Clone)] -pub(crate) struct PatternParserOutput<'a, I> -where - I: ValueInput<'a, Token = Token<'a>, Span = Span>, -{ - pub(crate) pattern: BoxedParser<'a, I, pattern::Pattern>, -} - -pub(crate) fn parser<'a, I>() -> PatternParserOutput<'a, I> -where - I: ValueInput<'a, Token = Token<'a>, Span = Span>, -{ - let xpath_parser_output = xpath_parser(); - let expr_single = xpath_parser_output.expr_single_core; - let name_output = parser_name(); - let name = name_output.eqname; - let parser_primary_output = parser_primary(name.clone()); - let literal = parser_primary_output.literal; - let var_ref = parser_primary_output.var_ref; - let parser_axis_node_test_output = - parser_axis_node_test(name.clone(), xpath_parser_output.kind_test); - let node_test = parser_axis_node_test_output.node_test; - let abbrev_forward_step = parser_axis_node_test_output.abbrev_forward_step; - - // HACK: a bit of repetition here to produce predicate_list, as getting it out - // of the xpath parser seems to lead to recursive parser errors - let expr = expr_single - .clone() - .separated_by(just(Token::Comma)) - .at_least(1) - .collect::>() - .map_with(|exprs, extra| ast::Expr(exprs).with_span(extra.span())) - .boxed(); - let predicate = expr - .clone() - .delimited_by(just(Token::LeftBracket), just(Token::RightBracket)) - .boxed(); - let 
predicate_list = predicate.repeated().collect::>().boxed(); - - let predicate_pattern = (just(Token::Dot).ignore_then(predicate_list.clone())) - .map(|predicates| pattern::PredicatePattern { predicates }) - .boxed(); - - let outer_function_name = name.try_map(|name, span| { - let name = name.value; - if name.namespace() == FN_NAMESPACE || name.namespace().is_empty() { - { - match name.local_name() { - "doc" => Ok(pattern::OuterFunctionName::Doc), - "id" => Ok(pattern::OuterFunctionName::Id), - "element-with-id" => Ok(pattern::OuterFunctionName::ElementWithId), - "key" => Ok(pattern::OuterFunctionName::Key), - "root" => Ok(pattern::OuterFunctionName::Root), - _ => Err(ParserError::IllegalFunctionInPattern { name, span }), - } +struct RootStep { + root: pattern::RootExpr, + predicates: Vec, +} + +fn unsupported_pattern(span: Span) -> ParserError { + ParserError::ExpectedFound { span } +} + +fn span_is_empty(span: Span) -> bool { + span.start == span.end +} + +fn span_len(span: Span) -> usize { + span.end.saturating_sub(span.start) +} + +fn expr_to_pattern( + expr: &ast::Expr, + span: Span, +) -> Result, ParserError> { + if expr.0.len() != 1 { + return Err(unsupported_pattern(span)); + } + expr_single_to_pattern(&expr.0[0]) +} + +fn expr_to_expr_pattern( + expr: &ast::Expr, + span: Span, +) -> Result, ParserError> { + if expr.0.len() != 1 { + return Err(unsupported_pattern(span)); + } + expr_single_to_expr_pattern(&expr.0[0]) +} + +fn expr_single_to_pattern( + expr_single: &ast::ExprSingleS, +) -> Result, ParserError> { + match &expr_single.value { + ast::ExprSingle::Path(path_expr) => { + if let Some(predicates) = context_item_predicates(path_expr)? 
{ + return Ok(pattern::Pattern::Predicate(pattern::PredicatePattern { + predicates, + })); } - } else { - Err(ParserError::IllegalFunctionInPattern { name, span }) + Ok(pattern::Pattern::Expr(pattern::ExprPattern::Path( + path_expr_to_pattern(path_expr)?, + ))) } - }); + ast::ExprSingle::Binary(binary_expr) => Ok(pattern::Pattern::Expr(convert_binary_expr( + binary_expr, + expr_single.span, + )?)), + _ => Err(unsupported_pattern(expr_single.span)), + } +} - let argument = var_ref - .clone() - .map(|var_ref| { - if let ast::PrimaryExpr::VarRef(name) = var_ref.value { - pattern::Argument::VarRef(name) - } else { - unreachable!() - } - }) - .or(literal.map(|literal| { - if let ast::PrimaryExpr::Literal(literal) = literal.value { - pattern::Argument::Literal(literal) - } else { - unreachable!() - } - })); +fn expr_single_to_expr_pattern( + expr_single: &ast::ExprSingleS, +) -> Result, ParserError> { + match &expr_single.value { + ast::ExprSingle::Path(path_expr) => path_expr_to_expr_pattern(path_expr), + ast::ExprSingle::Binary(binary_expr) => { + convert_binary_expr(binary_expr, expr_single.span) + } + _ => Err(unsupported_pattern(expr_single.span)), + } +} - let argument_list = (argument.separated_by(just(Token::Comma))) - .at_least(1) - .collect::>() - .delimited_by(just(Token::LeftParen), just(Token::RightParen)) - .boxed(); +fn convert_binary_expr( + binary_expr: &ast::BinaryExpr, + span: Span, +) -> Result, ParserError> { + let operator = match binary_expr.operator { + ast::BinaryOperator::Union => pattern::Operator::Union, + ast::BinaryOperator::Intersect => pattern::Operator::Intersect, + ast::BinaryOperator::Except => pattern::Operator::Except, + _ => return Err(unsupported_pattern(span)), + }; + + let left = path_expr_to_expr_pattern(&binary_expr.left)?; + let right = path_expr_to_expr_pattern(&binary_expr.right)?; + Ok(pattern::ExprPattern::BinaryExpr(pattern::BinaryExpr { + operator, + left: Box::new(left), + right: Box::new(right), + })) +} - let 
function_call = outer_function_name.then(argument_list).boxed(); +fn path_expr_to_expr_pattern( + path_expr: &ast::PathExpr, +) -> Result, ParserError> { + if context_item_predicates(path_expr)?.is_some() { + let span = path_expr + .steps + .first() + .map(|step| step.span) + .unwrap_or_else(|| Span::new(0, 0)); + return Err(unsupported_pattern(span)); + } + Ok(pattern::ExprPattern::Path(path_expr_to_pattern( + path_expr, + )?)) +} - let rooted_var_ref = var_ref.map(|var_ref| { - if let ast::PrimaryExpr::VarRef(name) = var_ref.value { - pattern::RootExpr::VarRef(name) - } else { - unreachable!() - } - }); +fn path_expr_to_pattern( + path_expr: &ast::PathExpr, +) -> Result, ParserError> { + if context_item_predicates(path_expr)?.is_some() { + let span = path_expr + .steps + .first() + .map(|step| step.span) + .unwrap_or_else(|| Span::new(0, 0)); + return Err(unsupported_pattern(span)); + } - let rooted_function_call = function_call - .map(|(name, args)| pattern::RootExpr::FunctionCall(pattern::FunctionCall { name, args })); + let steps = &path_expr.steps; + if steps.is_empty() { + return Err(unsupported_pattern(Span::new(0, 0))); + } - let rooted_path_start = rooted_function_call.or(rooted_var_ref).boxed(); + if let Some((root, start_index)) = implicit_root_info(steps) { + if matches!(root, pattern::PathRoot::AbsoluteDoubleSlash) && steps.len() <= start_index { + return Err(unsupported_pattern(steps[0].span)); + } + let steps = convert_steps(&steps[start_index..])?; + return Ok(pattern::PathExpr { root, steps }); + } - let slash_or_double_slash = just(Token::Slash).or(just(Token::DoubleSlash)); + let (root, start_index) = if let Some(root_step) = root_step_info(steps.get(0))? 
{ + ( + pattern::PathRoot::Rooted { + root: root_step.root, + predicates: root_step.predicates, + }, + 1, + ) + } else { + let steps = convert_steps(steps)?; + return Ok(finalize_relative_path(steps)); + }; + + let steps = convert_steps(&steps[start_index..])?; + Ok(pattern::PathExpr { root, steps }) +} - let expr_pattern = recursive(|expr_pattern| { - let parenthesized_expr = expr_pattern - .delimited_by(just(Token::LeftParen), just(Token::RightParen)) - .boxed(); +fn finalize_relative_path( + steps: Vec>, +) -> pattern::PathExpr { + if steps.len() == 1 { + if let pattern::StepExpr::PostfixExpr(postfix_expr) = &steps[0] { + if postfix_expr.predicates.is_empty() { + if let pattern::ExprPattern::Path(path_expr) = &postfix_expr.expr { + return path_expr.clone(); + } + } + } + } + pattern::PathExpr { + root: pattern::PathRoot::Relative, + steps, + } +} - let postfix_expr = parenthesized_expr.then(predicate_list.clone()).boxed(); - - let forward_axis = (just(Token::Child) - .or(just(Token::Descendant)) - .or(just(Token::Attribute)) - .or(just(Token::Self_)) - .or(just(Token::DescendantOrSelf)) - .or(just(Token::Namespace))) - .then_ignore(just(Token::DoubleColon)) - .map(|token| match token { - Token::Child => pattern::ForwardAxis::Child, - Token::Descendant => pattern::ForwardAxis::Descendant, - Token::Attribute => pattern::ForwardAxis::Attribute, - Token::Self_ => pattern::ForwardAxis::Self_, - Token::DescendantOrSelf => pattern::ForwardAxis::DescendantOrSelf, - Token::Namespace => pattern::ForwardAxis::Namespace, - _ => unreachable!(), - }) - .boxed(); - - let forward_step_axis_node_test = forward_axis.then(node_test); - let forward_step_abbrev = abbrev_forward_step.map(|(axis, node_test)| { - let axis = match axis { - ast::Axis::Attribute => pattern::ForwardAxis::Attribute, - ast::Axis::Child => pattern::ForwardAxis::Child, - _ => unreachable!(), - }; - (axis, node_test) - }); - - let forward_step = forward_step_axis_node_test.or(forward_step_abbrev); - - let 
axis_step = forward_step.then(predicate_list.clone()); - - let step_expr = postfix_expr - .map(|(expr, predicates)| { - pattern::StepExpr::PostfixExpr(pattern::PostfixExpr { expr, predicates }) - }) - .or(axis_step.map(|((axis, node_test), predicates)| { - pattern::StepExpr::AxisStep(pattern::AxisStep { - forward: axis, - node_test, - predicates, - }) +fn convert_steps( + steps: &[ast::StepExprS], +) -> Result>, ParserError> { + steps + .iter() + .map(convert_step_expr) + .collect::, _>>() +} + +fn convert_step_expr( + step_expr: &ast::StepExprS, +) -> Result, ParserError> { + match &step_expr.value { + ast::StepExpr::AxisStep(axis_step) => Ok(pattern::StepExpr::AxisStep(convert_axis_step( + axis_step, + ))), + ast::StepExpr::PostfixExpr { primary, postfixes } => { + let predicates = collect_predicates(postfixes, step_expr.span)?; + let expr = primary_expr_to_expr_pattern(primary)?; + Ok(pattern::StepExpr::PostfixExpr(pattern::PostfixExpr { + expr, + predicates, })) - .boxed(); + } + ast::StepExpr::PrimaryExpr(primary) => { + let expr = primary_expr_to_expr_pattern(primary)?; + Ok(pattern::StepExpr::PostfixExpr(pattern::PostfixExpr { + expr, + predicates: Vec::new(), + })) + } + } +} - let relative_path_expr = step_expr - .clone() - .then( - (slash_or_double_slash.then(step_expr)) - .repeated() - .collect::>(), - ) - .map(|(first_step, rest_steps)| { - let mut steps = vec![first_step]; - for (token, step) in rest_steps { - match token { - Token::Slash => {} - Token::DoubleSlash => { - let axis_step = pattern::AxisStep { - forward: pattern::ForwardAxis::DescendantOrSelf, - node_test: ast::NodeTest::KindTest(ast::KindTest::Any), - predicates: vec![], - }; - steps.push(pattern::StepExpr::AxisStep(axis_step)); - } - _ => unreachable!(), - } - steps.push(step); - } - steps - }) - .boxed(); +fn primary_expr_to_expr_pattern( + primary: &ast::PrimaryExprS, +) -> Result, ParserError> { + match &primary.value { + ast::PrimaryExpr::Expr(expr_or_empty) => 
expr_or_empty_to_expr_pattern(expr_or_empty), + _ => Err(unsupported_pattern(primary.span)), + } +} - let rooted_path = rooted_path_start - .then(predicate_list) - .then( - (just(Token::Slash) - .or(just(Token::DoubleSlash)) - .then(relative_path_expr.clone())) - .or_not(), - ) - .map(|((root, predicates), token_relative_steps)| { - let steps = if let Some((token, relative_steps)) = token_relative_steps { - match token { - Token::Slash => relative_steps, - Token::DoubleSlash => { - let axis_step = pattern::AxisStep { - forward: pattern::ForwardAxis::DescendantOrSelf, - node_test: ast::NodeTest::KindTest(ast::KindTest::Any), - predicates: vec![], - }; - let mut steps = vec![pattern::StepExpr::AxisStep(axis_step)]; - steps.extend(relative_steps); - steps - } - _ => unreachable!(), - } - } else { - vec![] - }; - pattern::PathExpr { - root: pattern::PathRoot::Rooted { root, predicates }, - steps, - } - }); - let absolute_slash_path = just(Token::Slash) - .ignore_then(relative_path_expr.clone().or_not()) - .map(|steps| pattern::PathExpr { - root: pattern::PathRoot::AbsoluteSlash, - steps: steps.unwrap_or_default(), - }); - let absolute_double_slash_path = just(Token::DoubleSlash) - .ignore_then(relative_path_expr.clone()) - .map(|steps| pattern::PathExpr { - root: pattern::PathRoot::AbsoluteDoubleSlash, - steps, - }); - let relative_path = relative_path_expr.map(|steps| { - // shortcut to create an absolute path if that's possible. 
- // The use of parenthesized expr can otherwise turn stuff into - // a postfix expr even though it's actually a simple path expr - if steps.len() == 1 { - if let pattern::StepExpr::PostfixExpr(postfix_expr) = &steps[0] { - if postfix_expr.predicates.is_empty() { - if let pattern::ExprPattern::Path(path_expr) = &postfix_expr.expr { - return path_expr.clone(); - } - } - } +fn expr_or_empty_to_expr_pattern( + expr_or_empty: &ast::ExprOrEmptyS, +) -> Result, ParserError> { + match &expr_or_empty.value { + Some(expr) => expr_to_expr_pattern(expr, expr_or_empty.span), + None => Err(unsupported_pattern(expr_or_empty.span)), + } +} + +fn convert_axis_step(axis_step: &ast::AxisStep) -> pattern::AxisStep { + pattern::AxisStep { + forward: match axis_step.axis { + ast::Axis::Child => pattern::ForwardAxis::Child, + ast::Axis::Descendant => pattern::ForwardAxis::Descendant, + ast::Axis::Attribute => pattern::ForwardAxis::Attribute, + ast::Axis::Self_ => pattern::ForwardAxis::Self_, + ast::Axis::DescendantOrSelf => pattern::ForwardAxis::DescendantOrSelf, + ast::Axis::Namespace => pattern::ForwardAxis::Namespace, + _ => pattern::ForwardAxis::Child, + }, + node_test: axis_step.node_test.clone(), + predicates: axis_step.predicates.clone(), + } +} + +fn collect_predicates( + postfixes: &[ast::Postfix], + span: Span, +) -> Result, ParserError> { + let mut predicates = Vec::new(); + for postfix in postfixes { + match postfix { + ast::Postfix::Predicate(expr) => predicates.push(expr.clone()), + _ => return Err(unsupported_pattern(span)), + } + } + Ok(predicates) +} + +fn context_item_predicates( + path_expr: &ast::PathExpr, +) -> Result>, ParserError> { + if path_expr.steps.len() != 1 { + return Ok(None); + } + let step = &path_expr.steps[0]; + match &step.value { + ast::StepExpr::PrimaryExpr(primary) => { + if matches!(primary.value, ast::PrimaryExpr::ContextItem) { + Ok(Some(Vec::new())) + } else { + Ok(None) } - pattern::PathExpr { - root: pattern::PathRoot::Relative, - steps, + } 
+ ast::StepExpr::PostfixExpr { primary, postfixes } => { + if matches!(primary.value, ast::PrimaryExpr::ContextItem) { + Ok(Some(collect_predicates(postfixes, step.span)?)) + } else { + Ok(None) } - }); + } + _ => Ok(None), + } +} - let path_expr = absolute_slash_path - .or(absolute_double_slash_path) - .or(relative_path) - .or(rooted_path) - .boxed(); +fn root_step_info( + step_expr: Option<&ast::StepExprS>, +) -> Result, ParserError> { + let step_expr = match step_expr { + Some(step_expr) => step_expr, + None => return Ok(None), + }; + match &step_expr.value { + ast::StepExpr::PrimaryExpr(primary) => { + if let Some(root) = root_from_primary(primary)? { + Ok(Some(RootStep { + root, + predicates: Vec::new(), + })) + } else { + Ok(None) + } + } + ast::StepExpr::PostfixExpr { primary, postfixes } => { + if let Some(root) = root_from_primary(primary)? { + let predicates = collect_predicates(postfixes, step_expr.span)?; + Ok(Some(RootStep { root, predicates })) + } else { + Ok(None) + } + } + _ => Ok(None), + } +} - let operator = just(Token::Intersect) - .or(just(Token::Except)) - .or(just(Token::Union)) - .or(just(Token::Pipe)) - .map(|token| match token { - Token::Intersect => pattern::Operator::Intersect, - Token::Except => pattern::Operator::Except, - Token::Union => pattern::Operator::Union, - Token::Pipe => pattern::Operator::Union, - _ => unreachable!(), - }); - - let expr_pattern = (path_expr.clone().map(pattern::ExprPattern::Path)) - .foldl( - operator.then(path_expr.clone()).repeated(), - |left, (operator, right)| { - pattern::ExprPattern::BinaryExpr(pattern::BinaryExpr { - operator, - left: Box::new(left), - right: Box::new(pattern::ExprPattern::Path(right)), - }) - }, - ) - .boxed(); +fn root_from_primary( + primary: &ast::PrimaryExprS, +) -> Result, ParserError> { + match &primary.value { + ast::PrimaryExpr::VarRef(name) => Ok(Some(pattern::RootExpr::VarRef(name.clone()))), + ast::PrimaryExpr::FunctionCall(call) => { + 
Ok(Some(pattern::RootExpr::FunctionCall(convert_function_call( + call, primary.span, + )?))) + } + _ => Ok(None), + } +} + +fn convert_function_call( + call: &ast::FunctionCall, + span: Span, +) -> Result { + let name = convert_outer_function_name(&call.name)?; + if call.arguments.is_empty() { + return Err(unsupported_pattern(span)); + } + let args = call + .arguments + .iter() + .map(convert_argument) + .collect::, _>>()?; + Ok(pattern::FunctionCall { name, args }) +} - expr_pattern - }) - .boxed(); +fn convert_outer_function_name( + name: &ast::NameS, +) -> Result { + let value = &name.value; + if value.namespace() == FN_NAMESPACE || value.namespace().is_empty() { + match value.local_name() { + "doc" => Ok(pattern::OuterFunctionName::Doc), + "id" => Ok(pattern::OuterFunctionName::Id), + "element-with-id" => Ok(pattern::OuterFunctionName::ElementWithId), + "key" => Ok(pattern::OuterFunctionName::Key), + "root" => Ok(pattern::OuterFunctionName::Root), + _ => Err(ParserError::IllegalFunctionInPattern { + name: value.clone(), + span: name.span, + }), + } + } else { + Err(ParserError::IllegalFunctionInPattern { + name: value.clone(), + span: name.span, + }) + } +} - let predicate_pattern = predicate_pattern - .then_ignore(end()) - .map(pattern::Pattern::Predicate) - .boxed(); +fn convert_argument(expr_single: &ast::ExprSingleS) -> Result { + match &expr_single.value { + ast::ExprSingle::Path(path_expr) => argument_from_path(path_expr, expr_single.span), + _ => Err(unsupported_pattern(expr_single.span)), + } +} - let union_pattern = expr_pattern - .then_ignore(end()) - .map(pattern::Pattern::Expr) - .boxed(); +fn argument_from_path( + path_expr: &ast::PathExpr, + span: Span, +) -> Result { + if path_expr.steps.len() != 1 { + return Err(unsupported_pattern(span)); + } + match &path_expr.steps[0].value { + ast::StepExpr::PrimaryExpr(primary) => match &primary.value { + ast::PrimaryExpr::VarRef(name) => Ok(pattern::Argument::VarRef(name.clone())), + 
ast::PrimaryExpr::Literal(literal) => { + Ok(pattern::Argument::Literal(literal.clone())) + } + _ => Err(unsupported_pattern(primary.span)), + }, + _ => Err(unsupported_pattern(span)), + } +} - let pattern = predicate_pattern.or(union_pattern).boxed(); +fn implicit_root_info( + steps: &[ast::StepExprS], +) -> Option<(pattern::PathRoot, usize)> { + let (first, rest) = steps.split_first()?; + if !is_implicit_root_step(first) { + return None; + } + if let Some((second, _rest)) = rest.split_first() { + if is_implicit_descendant_step(second, first.span) { + return Some((pattern::PathRoot::AbsoluteDoubleSlash, 2)); + } + } + Some((pattern::PathRoot::AbsoluteSlash, 1)) +} - PatternParserOutput { pattern } +fn is_implicit_root_step(step: &ast::StepExprS) -> bool { + match &step.value { + ast::StepExpr::PrimaryExpr(primary) => match &primary.value { + ast::PrimaryExpr::FunctionCall(call) => { + if !span_is_empty(call.name.span) { + return false; + } + let name = &call.name.value; + if name.namespace() != FN_NAMESPACE || name.local_name() != "root" { + return false; + } + call.arguments.len() == 1 && argument_is_self_node(&call.arguments[0]) + } + _ => false, + }, + _ => false, + } +} + +fn argument_is_self_node(expr_single: &ast::ExprSingleS) -> bool { + match &expr_single.value { + ast::ExprSingle::Path(path_expr) => { + if path_expr.steps.len() != 1 { + return false; + } + match &path_expr.steps[0].value { + ast::StepExpr::AxisStep(axis_step) => { + axis_step.axis == ast::Axis::Self_ + && matches!(axis_step.node_test, ast::NodeTest::KindTest(ast::KindTest::Any)) + && axis_step.predicates.is_empty() + } + _ => false, + } + } + _ => false, + } +} + +fn is_implicit_descendant_step(step: &ast::StepExprS, root_span: Span) -> bool { + if !is_descendant_or_self_any(step) { + return false; + } + if span_is_empty(step.span) { + return true; + } + let step_len = span_len(step.span); + let root_len = span_len(root_span); + step_len == root_len && step_len <= 2 +} + +fn 
is_descendant_or_self_any(step: &ast::StepExprS) -> bool { + match &step.value { + ast::StepExpr::AxisStep(axis_step) => { + axis_step.axis == ast::Axis::DescendantOrSelf + && matches!(axis_step.node_test, ast::NodeTest::KindTest(ast::KindTest::Any)) + && axis_step.predicates.is_empty() + } + _ => false, + } } impl pattern::Pattern { pub fn parse<'a>( input: &'a str, namespaces: &'a Namespaces, - _variable_names: &'a VariableNames, + variable_names: &'a VariableNames, ) -> Result { - let pattern = parse(parser().pattern, tokens(input), Cow::Borrowed(namespaces))?; - // TODO: do we need to rename variables to unique names? probably - Ok(pattern) + let ast::XPath(expr) = ast::XPath::parse(input, namespaces, variable_names)?; + expr_to_pattern(&expr.value, expr.span) } } #[cfg(test)] mod tests { - use ahash::HashSetExt; + use chumsky::prelude::*; use insta::assert_ron_snapshot; + use std::borrow::Cow; + use xee_xpath_lexer::Token; + + use super::super::axis_node_test::parser_axis_node_test; + use super::super::kind_test::parser_kind_test; + use super::super::name::parser_name; + use super::super::primary::parser_primary; + use super::super::{parse, tokens}; use super::*; #[test] fn test_predicate_pattern_no_predicates() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse(".", &namespaces, &variable_names)); } #[test] fn test_predicate_pattern_single_predicate() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( ".[1]", &namespaces, @@ -347,10 +509,89 @@ mod tests { )); } + #[test] + fn test_predicate_pattern_dot_equals() { + let namespaces = Namespaces::default(); + let variable_names = VariableNames::default(); + assert_ron_snapshot!(pattern::Pattern::parse( + ".[.='10']", + &namespaces, + &variable_names + )); 
+ } + + #[test] + fn test_text_predicate_pattern() { + let namespaces = Namespaces::default(); + let variable_names = VariableNames::default(); + assert_ron_snapshot!(pattern::Pattern::parse( + "text()[.='10']", + &namespaces, + &variable_names + )); + } + + #[test] + fn test_text_pattern() { + let namespaces = Namespaces::default(); + let variable_names = VariableNames::default(); + assert_ron_snapshot!(pattern::Pattern::parse( + "text()", + &namespaces, + &variable_names + )); + } + + #[test] + fn test_text_predicate_numeric_pattern() { + let namespaces = Namespaces::default(); + let variable_names = VariableNames::default(); + assert_ron_snapshot!(pattern::Pattern::parse( + "text()[1]", + &namespaces, + &variable_names + )); + } + + #[test] + fn test_axis_node_test_parse() { + let namespaces = Namespaces::default(); + let parser_name_output = parser_name(); + let name = parser_name_output.eqname; + let ncname = parser_name_output.ncname; + let parser_primary_output = parser_primary(name.clone()); + let string = parser_primary_output.string; + let empty_call = just(Token::LeftParen) + .ignore_then(just(Token::RightParen)) + .boxed(); + let kind_test = parser_kind_test(name.clone(), empty_call, ncname, string).kind_test; + let parser_axis_node_test_output = parser_axis_node_test(name, kind_test); + let axis_node_test = parser_axis_node_test_output + .axis_node_test + .then_ignore(end()) + .boxed(); + assert_ron_snapshot!(parse( + axis_node_test, + tokens("text()"), + Cow::Borrowed(&namespaces), + )); + } + + #[test] + fn test_predicate_expr_parse() { + let namespaces = Namespaces::default(); + let variable_names = VariableNames::default(); + assert_ron_snapshot!(ast::XPath::parse( + ".='10'", + &namespaces, + &variable_names + )); + } + #[test] fn test_expr_pattern() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "$a | $b", 
&namespaces, @@ -361,7 +602,7 @@ mod tests { #[test] fn test_expr_pattern_rooted_path() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "$a/foo", &namespaces, @@ -372,7 +613,7 @@ mod tests { #[test] fn test_expr_pattern_absolute_slash() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "/foo", &namespaces, @@ -383,7 +624,7 @@ mod tests { #[test] fn test_expr_pattern_absolute_double_slash() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "//foo", &namespaces, @@ -394,28 +635,28 @@ mod tests { #[test] fn test_absolute_slash_without_steps() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse("/", &namespaces, &variable_names)); } #[test] fn test_absolute_slash_without_steps_in_parenthesis() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse("(/)", &namespaces, &variable_names)); } #[test] fn test_expr_pattern_relative() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse("foo", &namespaces, &variable_names)); } #[test] fn test_postfix_expr() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "foo[1]", &namespaces, @@ -423,10 +664,28 @@ mod tests { )); } + #[test] + 
fn test_nested_predicate_parses() { + let namespaces = Namespaces::default(); + let variable_names = VariableNames::default(); + assert!(pattern::Pattern::parse( + "foo[(bar[2])='this']", + &namespaces, + &variable_names + ) + .is_ok()); + assert!(pattern::Pattern::parse( + "foo[(bar[2][(baz[2])='goodbye'])]", + &namespaces, + &variable_names + ) + .is_ok()); + } + #[test] fn test_union() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "foo | bar", &namespaces, @@ -437,7 +696,7 @@ mod tests { #[test] fn test_intersect() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "foo intersect bar", &namespaces, @@ -448,7 +707,7 @@ mod tests { #[test] fn test_union_with_intersect() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "foo intersect bar | baz", &namespaces, @@ -459,7 +718,7 @@ mod tests { #[test] fn test_union_with_union() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "foo | (bar | baz)", &namespaces, @@ -470,7 +729,7 @@ mod tests { #[test] fn test_intersect_with_union() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); assert_ron_snapshot!(pattern::Pattern::parse( "foo intersect (bar | baz)", &namespaces, @@ -481,7 +740,7 @@ mod tests { #[test] fn test_root_intersect_with_other_path() { let namespaces = Namespaces::default(); - let variable_names = VariableNames::new(); + let variable_names = VariableNames::default(); // have to use bracketrs here, as 
otherwise 'intersect' is interpreted // as an element name as per xpath rules assert_ron_snapshot!(pattern::Pattern::parse( diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__axis_node_test_parse.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__axis_node_test_parse.snap new file mode 100644 index 000000000..16cfff060 --- /dev/null +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__axis_node_test_parse.snap @@ -0,0 +1,6 @@ +--- +source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 450 +expression: "parse(axis_node_test, tokens(\"text()\"), Cow::Borrowed(&namespaces))" +--- +Ok((Child, KindTest(Text))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__expr_pattern.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__expr_pattern.snap index f938371e0..be30f655f 100644 --- a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__expr_pattern.snap +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__expr_pattern.snap @@ -1,29 +1,38 @@ --- source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 892 expression: "pattern::Pattern::parse(\"$a | $b\", &namespaces, &variable_names)" --- -Ok(Expr(BinaryExpr(BinaryExpr( - operator: Union, - left: Path(PathExpr( - root: Rooted( - root: VarRef(OwnedName( - local_name_str: "a", - namespace_str: "", - prefix_str: "", +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Union, + left: Path(PathExpr( + root: Rooted( + root: VarRef(OwnedName( + local_name_str: "a", + namespace_str: "", + prefix_str: "", + )), + predicates: [], + ), + steps: [], + )), + right: Path(PathExpr( + root: Rooted( + root: VarRef(OwnedName( + local_name_str: "b", + namespace_str: "", + prefix_str: "", + )), + predicates: [], + ), + steps: [], + )), )), predicates: [], - ), - 
steps: [], - )), - right: Path(PathExpr( - root: Rooted( - root: VarRef(OwnedName( - local_name_str: "b", - namespace_str: "", - prefix_str: "", - )), - predicates: [], - ), - steps: [], - )), + )), + ], )))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect.snap index 05f859f4f..8d8cdf28e 100644 --- a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect.snap +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect.snap @@ -1,35 +1,44 @@ --- source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 997 expression: "pattern::Pattern::parse(\"foo intersect bar\", &namespaces, &variable_names)" --- -Ok(Expr(BinaryExpr(BinaryExpr( - operator: Intersect, - left: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "foo", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Intersect, + left: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "foo", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + right: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "bar", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), )), - ], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "bar", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], - )), - ], - )), + predicates: [], + )), + ], )))) diff --git 
a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect_with_union.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect_with_union.snap index 764908774..f605199fb 100644 --- a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect_with_union.snap +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__intersect_with_union.snap @@ -1,60 +1,69 @@ --- source: xee-xpath-ast/src/parser/pattern.rs -expression: "pattern::Pattern::parse(\"foo intersect (bar | baz)\", &namespaces,\n &variable_names)" +assertion_line: 1030 +expression: "pattern::Pattern::parse(\"foo intersect (bar | baz)\", &namespaces,\n&variable_names)" --- -Ok(Expr(BinaryExpr(BinaryExpr( - operator: Intersect, - left: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "foo", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], - )), - ], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - PostfixExpr(PostfixExpr( - expr: BinaryExpr(BinaryExpr( - operator: Union, - left: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "bar", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], - )), - ], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "baz", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Intersect, + left: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "foo", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + right: 
Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Union, + left: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "bar", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + right: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "baz", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), )), - ], - )), + predicates: [], + )), + ], )), - predicates: [], )), - ], - )), + predicates: [], + )), + ], )))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__predicate_expr_parse.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__predicate_expr_parse.snap new file mode 100644 index 000000000..2ba1f9d33 --- /dev/null +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__predicate_expr_parse.snap @@ -0,0 +1,26 @@ +--- +source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 358 +expression: "ast::XPath::parse(\".='10'\", &namespaces, &variable_names)" +--- +Ok(XPath(Expr([ + Path(PathExpr( + steps: [ + PrimaryExpr(Expr(Some(Expr([ + Binary(BinaryExpr( + operator: GenEq, + left: PathExpr( + steps: [ + PrimaryExpr(ContextItem), + ], + ), + right: PathExpr( + steps: [ + PrimaryExpr(Literal(String("10"))), + ], + ), + )), + ])))), + ], + )), +]))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__predicate_pattern_dot_equals.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__predicate_pattern_dot_equals.snap new file mode 100644 index 000000000..2b5ce99e7 --- /dev/null +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__predicate_pattern_dot_equals.snap @@ -0,0 +1,30 @@ +--- +source: 
xee-xpath-ast/src/parser/pattern.rs +assertion_line: 347 +expression: "pattern::Pattern::parse(\".[.='10']\", &namespaces, &variable_names)" +--- +Ok(Predicate(PredicatePattern( + predicates: [ + Expr([ + Path(PathExpr( + steps: [ + PrimaryExpr(Expr(Some(Expr([ + Binary(BinaryExpr( + operator: GenEq, + left: PathExpr( + steps: [ + PrimaryExpr(ContextItem), + ], + ), + right: PathExpr( + steps: [ + PrimaryExpr(Literal(String("10"))), + ], + ), + )), + ])))), + ], + )), + ]), + ], +))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__root_intersect_with_other_path.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__root_intersect_with_other_path.snap index 18c72eb08..407c7a1d6 100644 --- a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__root_intersect_with_other_path.snap +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__root_intersect_with_other_path.snap @@ -1,25 +1,34 @@ --- source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 1043 expression: "pattern::Pattern::parse(\"(/) intersect foo\", &namespaces, &variable_names)" --- -Ok(Expr(BinaryExpr(BinaryExpr( - operator: Intersect, - left: Path(PathExpr( - root: AbsoluteSlash, - steps: [], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "foo", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Intersect, + left: Path(PathExpr( + root: AbsoluteSlash, + steps: [], + )), + right: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "foo", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), )), - ], - )), + predicates: [], + )), + ], )))) 
diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_pattern.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_pattern.snap new file mode 100644 index 000000000..d9e0f8381 --- /dev/null +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_pattern.snap @@ -0,0 +1,15 @@ +--- +source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 358 +expression: "pattern::Pattern::parse(\"text()\", &namespaces, &variable_names)" +--- +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: KindTest(Text), + predicates: [], + )), + ], +)))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_predicate_numeric_pattern.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_predicate_numeric_pattern.snap new file mode 100644 index 000000000..411475647 --- /dev/null +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_predicate_numeric_pattern.snap @@ -0,0 +1,25 @@ +--- +source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 390 +expression: "pattern::Pattern::parse(\"text()[1]\", &namespaces, &variable_names)" +--- +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: KindTest(Text), + predicates: [ + Expr([ + Path(PathExpr( + steps: [ + PrimaryExpr(Literal(Integer((Positive, [ + 1, + ])))), + ], + )), + ]), + ], + )), + ], +)))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_predicate_pattern.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_predicate_pattern.snap new file mode 100644 index 000000000..713613152 --- /dev/null +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__text_predicate_pattern.snap @@ -0,0 +1,37 @@ +--- +source: 
xee-xpath-ast/src/parser/pattern.rs +assertion_line: 821 +expression: "pattern::Pattern::parse(\"text()[.='10']\", &namespaces, &variable_names)" +--- +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: KindTest(Text), + predicates: [ + Expr([ + Path(PathExpr( + steps: [ + PrimaryExpr(Expr(Some(Expr([ + Binary(BinaryExpr( + operator: GenEq, + left: PathExpr( + steps: [ + PrimaryExpr(ContextItem), + ], + ), + right: PathExpr( + steps: [ + PrimaryExpr(Literal(String("10"))), + ], + ), + )), + ])))), + ], + )), + ]), + ], + )), + ], +)))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union.snap index 8dfb59654..c1c87b040 100644 --- a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union.snap +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union.snap @@ -1,35 +1,44 @@ --- source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 986 expression: "pattern::Pattern::parse(\"foo | bar\", &namespaces, &variable_names)" --- -Ok(Expr(BinaryExpr(BinaryExpr( - operator: Union, - left: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "foo", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Union, + left: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "foo", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + right: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "bar", + namespace_str: "", + prefix_str: "", + 
))), + predicates: [], + )), + ], + )), )), - ], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "bar", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], - )), - ], - )), + predicates: [], + )), + ], )))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_intersect.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_intersect.snap index 9e27fcd67..abb5fdba9 100644 --- a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_intersect.snap +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_intersect.snap @@ -1,52 +1,69 @@ --- source: xee-xpath-ast/src/parser/pattern.rs -expression: "pattern::Pattern::parse(\"foo intersect bar | baz\", &namespaces,\n &variable_names)" +assertion_line: 1008 +expression: "pattern::Pattern::parse(\"foo intersect bar | baz\", &namespaces,\n&variable_names)" --- -Ok(Expr(BinaryExpr(BinaryExpr( - operator: Union, - left: BinaryExpr(BinaryExpr( - operator: Intersect, - left: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "foo", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Union, + left: Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Intersect, + left: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "foo", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + right: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + 
node_test: NameTest(Name(OwnedName( + local_name_str: "bar", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + )), + predicates: [], + )), + ], )), - ], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "bar", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], + right: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "baz", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], )), - ], - )), - )), - right: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "baz", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], )), - ], - )), + predicates: [], + )), + ], )))) diff --git a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_union.snap b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_union.snap index 319bfa59b..212d05160 100644 --- a/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_union.snap +++ b/xee-xpath-ast/src/parser/snapshots/xee_xpath_ast__parser__pattern__tests__union_with_union.snap @@ -1,60 +1,69 @@ --- source: xee-xpath-ast/src/parser/pattern.rs +assertion_line: 1019 expression: "pattern::Pattern::parse(\"foo | (bar | baz)\", &namespaces, &variable_names)" --- -Ok(Expr(BinaryExpr(BinaryExpr( - operator: Union, - left: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "foo", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], - )), - ], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - PostfixExpr(PostfixExpr( - expr: BinaryExpr(BinaryExpr( - operator: Union, - left: Path(PathExpr( - 
root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "bar", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], - )), - ], - )), - right: Path(PathExpr( - root: Relative, - steps: [ - AxisStep(AxisStep( - forward: Child, - node_test: NameTest(Name(OwnedName( - local_name_str: "baz", - namespace_str: "", - prefix_str: "", - ))), - predicates: [], +Ok(Expr(Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Union, + left: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "foo", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + right: Path(PathExpr( + root: Relative, + steps: [ + PostfixExpr(PostfixExpr( + expr: BinaryExpr(BinaryExpr( + operator: Union, + left: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "bar", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), + right: Path(PathExpr( + root: Relative, + steps: [ + AxisStep(AxisStep( + forward: Child, + node_test: NameTest(Name(OwnedName( + local_name_str: "baz", + namespace_str: "", + prefix_str: "", + ))), + predicates: [], + )), + ], + )), )), - ], - )), + predicates: [], + )), + ], )), - predicates: [], )), - ], - )), + predicates: [], + )), + ], )))) diff --git a/xee-xpath-ast/src/parser/types.rs b/xee-xpath-ast/src/parser/types.rs index f3e8a773e..bcc424438 100644 --- a/xee-xpath-ast/src/parser/types.rs +++ b/xee-xpath-ast/src/parser/types.rs @@ -5,7 +5,6 @@ use std::borrow::Cow; use crate::error::ParserError; use crate::Namespaces; - pub(crate) struct State<'a> { pub(crate) namespaces: Cow<'a, Namespaces>, } diff --git a/xee-xpath-compiler/src/ast_ir.rs b/xee-xpath-compiler/src/ast_ir.rs index 272b36cdc..c0903419a 100644 --- 
a/xee-xpath-compiler/src/ast_ir.rs +++ b/xee-xpath-compiler/src/ast_ir.rs @@ -1,3 +1,4 @@ +use ahash::HashMap; use xee_interpreter::{context, error, error::Error, function, xml}; use xee_ir::{ir, ir::AtomS, Binding, Bindings, Variables}; use xee_schema_type::Xs; @@ -8,15 +9,53 @@ use xot::xmlname::NameStrInfo; pub struct IrConverter<'a> { variables: &'a mut Variables, static_context: &'a context::StaticContext, + user_functions: Option, fn_position: ast::Name, fn_last: ast::Name, } +#[derive(Debug, Clone)] +pub struct UserFunctions { + lookup: HashMap<(ast::Name, u8), usize>, + global_param_names: Vec, +} + +impl UserFunctions { + pub fn new( + lookup: HashMap<(ast::Name, u8), usize>, + global_param_names: Vec, + ) -> Self { + Self { + lookup, + global_param_names, + } + } +} + impl<'a> IrConverter<'a> { pub fn new(variables: &'a mut Variables, static_context: &'a context::StaticContext) -> Self { Self { variables, static_context, + user_functions: None, + fn_position: ast::Name::new( + "position".to_string(), + FN_NAMESPACE.to_string(), + String::new(), + ), + fn_last: ast::Name::new("last".to_string(), FN_NAMESPACE.to_string(), String::new()), + } + } + + pub fn new_with_user_functions( + variables: &'a mut Variables, + static_context: &'a context::StaticContext, + user_functions: UserFunctions, + ) -> Self { + Self { + variables, + static_context, + user_functions: Some(user_functions), fn_position: ast::Name::new( "position".to_string(), FN_NAMESPACE.to_string(), @@ -600,17 +639,40 @@ impl<'a> IrConverter<'a> { // advice: format!("Either the function name {:?} does not exist, or you are calling it with the wrong number of arguments ({})", ast.name, arity), let static_function_id = self .static_context - .function_id_by_name(&ast.name.value, arity as u8) - .ok_or(Error::XPST0017.with_ast_span(span))?; + .function_id_by_name(&ast.name.value, arity as u8); // TODO we don't know yet how to get the proper span here let empty_span = (0..0).into(); - let mut 
static_function_ref_bindings = - self.static_function_ref(static_function_id, empty_span); - let atom = static_function_ref_bindings.atom(); + let mut function_ref_bindings = if let Some(static_function_id) = static_function_id { + self.static_function_ref(static_function_id, empty_span) + } else if let Some(user_functions) = &self.user_functions { + let key = (ast.name.value.clone(), arity as u8); + let index = user_functions + .lookup + .get(&key) + .ok_or(Error::XPST0017.with_ast_span(span))?; + self.user_function_ref(*index, empty_span) + } else { + return Err(Error::XPST0017.with_ast_span(span)); + }; + let atom = function_ref_bindings.atom(); let (arg_bindings, atoms) = self.args(&ast.arguments)?; + let mut atoms = atoms; + if let Some(user_functions) = &self.user_functions { + if user_functions + .lookup + .contains_key(&(ast.name.value.clone(), arity as u8)) + { + atoms.extend( + user_functions + .global_param_names + .iter() + .map(|name| Spanned::new(ir::Atom::Variable(name.clone()), empty_span)), + ); + } + } let expr = ir::Expr::FunctionCall(ir::FunctionCall { atom, args: atoms }); let binding = self.variables.new_binding(expr, span); - Ok(static_function_ref_bindings + Ok(function_ref_bindings .concat(arg_bindings) .bind(binding)) } @@ -623,9 +685,19 @@ impl<'a> IrConverter<'a> { // advice: format!("Either the function name {:?} does not exist, or you are calling it with the wrong number of arguments ({})", ast.name, ast.arity), let static_function_id = self .static_context - .function_id_by_name(&ast.name.value, ast.arity) - .ok_or(Error::XPST0017.with_ast_span(span))?; - Ok(self.static_function_ref(static_function_id, span)) + .function_id_by_name(&ast.name.value, ast.arity); + if let Some(static_function_id) = static_function_id { + Ok(self.static_function_ref(static_function_id, span)) + } else if let Some(user_functions) = &self.user_functions { + let key = (ast.name.value.clone(), ast.arity); + let index = user_functions + .lookup + .get(&key) + 
.ok_or(Error::XPST0017.with_ast_span(span))?; + Ok(self.user_function_ref(*index, span)) + } else { + Err(Error::XPST0017.with_ast_span(span)) + } } fn static_function_ref( @@ -642,6 +714,13 @@ impl<'a> IrConverter<'a> { Bindings::new(binding) } + fn user_function_ref(&mut self, index: usize, span: Span) -> Bindings { + let atom = ir::Atom::Const(ir::Const::UserFunctionReference(index)); + let expr = ir::Expr::Atom(Spanned::new(atom, span)); + let binding = self.variables.new_binding(expr, span); + Bindings::new(binding) + } + fn args(&mut self, args: &[ast::ExprSingleS]) -> error::SpannedResult<(Bindings, Vec)> { if args.is_empty() { return Ok((Bindings::empty(), vec![])); diff --git a/xee-xpath-compiler/src/lib.rs b/xee-xpath-compiler/src/lib.rs index 54cd3d8cf..2cf5d27b1 100644 --- a/xee-xpath-compiler/src/lib.rs +++ b/xee-xpath-compiler/src/lib.rs @@ -9,5 +9,5 @@ pub use xee_xpath_ast::{Namespaces, VariableNames}; pub use xee_interpreter::interpreter::Runnable; pub use xee_interpreter::{atomic, context, error, interpreter, occurrence, sequence, string, xml}; -pub use crate::ast_ir::IrConverter; +pub use crate::ast_ir::{IrConverter, UserFunctions}; pub use crate::compile::{compile, parse}; diff --git a/xee-xpath-lexer/tests/test_lexer.rs b/xee-xpath-lexer/tests/test_lexer.rs index a0c5a358a..f37755c0d 100644 --- a/xee-xpath-lexer/tests/test_lexer.rs +++ b/xee-xpath-lexer/tests/test_lexer.rs @@ -531,3 +531,15 @@ fn test_function_name_026() { assert_eq!(lex.next(), Some((Token::RightParen, (26..27)))); assert_eq!(lex.next(), None); } + +#[test] +fn test_text_predicate_tokens() { + let mut lex = lexer("text()[1]"); + assert_eq!(lex.next(), Some((Token::Text, (0..4)))); + assert_eq!(lex.next(), Some((Token::LeftParen, (4..5)))); + assert_eq!(lex.next(), Some((Token::RightParen, (5..6)))); + assert_eq!(lex.next(), Some((Token::LeftBracket, (6..7)))); + assert_eq!(lex.next(), Some((Token::IntegerLiteral(ibig!(1)), (7..8)))); + assert_eq!(lex.next(), 
Some((Token::RightBracket, (8..9)))); + assert_eq!(lex.next(), None); +} diff --git a/xee-xpath/src/documents.rs b/xee-xpath/src/documents.rs index e82c710f3..562a3b3ea 100644 --- a/xee-xpath/src/documents.rs +++ b/xee-xpath/src/documents.rs @@ -1,6 +1,7 @@ use iri_string::types::IriStr; use xee_interpreter::{ context::DocumentsRef, + context::TypeTableRef, xml::{DocumentHandle, DocumentsError}, }; use xot::Xot; @@ -16,6 +17,7 @@ use xot::Xot; pub struct Documents { pub(crate) xot: Xot, pub(crate) documents: DocumentsRef, + pub(crate) type_table: TypeTableRef, } impl Documents { @@ -24,6 +26,7 @@ impl Documents { Self { xot: Xot::new(), documents: DocumentsRef::new(), + type_table: TypeTableRef::new(), } } @@ -61,6 +64,11 @@ impl Documents { &self.documents } + /// Get a reference to the schema type table. + pub fn type_table(&self) -> &TypeTableRef { + &self.type_table + } + /// Get a reference to the Xot arena pub fn xot(&self) -> &Xot { &self.xot diff --git a/xee-xslt-ast/src/ast_core.rs b/xee-xslt-ast/src/ast_core.rs index a2056c9af..4d71d0e54 100644 --- a/xee-xslt-ast/src/ast_core.rs +++ b/xee-xslt-ast/src/ast_core.rs @@ -748,6 +748,12 @@ impl From for OverrideContent { } } +impl From<Function> for Declaration { + fn from(i: Function) -> Self { + Declaration::Function(Box::new(i)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] pub enum Streamability { @@ -802,6 +808,12 @@ pub struct Import { pub span: Span, } +impl From<Import> for Declaration { + fn from(i: Import) -> Self { + Declaration::Import(Box::new(i)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] pub struct ImportSchema { @@ -821,6 +833,12 @@ pub struct Include { pub span: Span, } +impl From<Include> for Declaration { + fn from(i: Include) -> Self { + Declaration::Include(Box::new(i)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] pub struct Iterate { @@
-983,6 +1001,12 @@ pub struct Mode { pub span: Span, } +impl From<Mode> for Declaration { + fn from(i: Mode) -> Self { + Declaration::Mode(Box::new(i)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] pub enum OnNoMatch { @@ -1388,6 +1412,12 @@ impl From for OverrideContent { } } +impl From<Param> for Declaration { + fn from(i: Param) -> Self { + Declaration::Param(Box::new(i)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] pub struct PerformSort { @@ -1625,6 +1655,12 @@ pub struct Try { pub span: Span, } +impl From<Try> for SequenceConstructorItem { + fn from(i: Try) -> Self { + SequenceConstructorInstruction::Try(Box::new(i)).into() + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] pub enum TryCatchOrFallback { @@ -1705,6 +1741,12 @@ impl From for OverrideContent { } } +impl From<Variable> for Declaration { + fn from(v: Variable) -> Self { + Declaration::Variable(Box::new(v)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize))] pub struct When { @@ -1829,6 +1871,7 @@ pub struct ElementNode { pub name: Name, pub attributes: Vec<(Name, ValueTemplate)>, pub sequence_constructor: SequenceConstructor, + pub type_: Option, pub span: Span, } diff --git a/xee-xslt-ast/src/attributes.rs b/xee-xslt-ast/src/attributes.rs index d345ecc11..9232f2f92 100644 --- a/xee-xslt-ast/src/attributes.rs +++ b/xee-xslt-ast/src/attributes.rs @@ -122,18 +122,40 @@ impl<'a> Attributes<'a> { pub(crate) fn validate_unseen(&self) -> Result<(), AttributeError> { let unseen_attributes = self.unseen_attributes(); - if !unseen_attributes.is_empty() { - return Err(self.content.state.attribute_unexpected( - self.content.node, - unseen_attributes[0], - "unexpected attribute", - )); + if unseen_attributes.is_empty() { + return Ok(()); + } + let element_in_xsl = self.in_xsl_namespace(); + let xsl_ns_uri = self +
.content + .state + .xot + .namespace_str(self.content.state.names.xsl_ns); + for name in unseen_attributes { + let ns_uri = self.content.state.xot.uri_str(name); + let is_xsl = ns_uri == xsl_ns_uri; + let is_null = ns_uri.is_empty(); + let should_error = if element_in_xsl { + // On XSLT elements, only attributes in null/XSLT namespaces are constrained. + is_null || is_xsl + } else { + // On literal result elements, only XSLT-namespaced attributes are constrained. + is_xsl + }; + if should_error { + return Err(self.content.state.attribute_unexpected( + self.content.node, + name, + "unexpected attribute", + )); + } } Ok(()) } fn _boolean(s: &str, span: Span) -> Result { - match s { + let value = s.trim(); + match value { "yes" | "true" | "1" => Ok(true), "no" | "false" | "0" => Ok(false), _ => Err(AttributeError::Invalid { diff --git a/xee-xslt-ast/src/combinator.rs b/xee-xslt-ast/src/combinator.rs index 2caf79bcf..7fb737136 100644 --- a/xee-xslt-ast/src/combinator.rs +++ b/xee-xslt-ast/src/combinator.rs @@ -443,13 +443,14 @@ impl, PB: NodeParser> NodeParser for OrParser Result<(V, Option)> { - // try the first parser, if that works, return result - // if it isn't working, try the other parser - let r = self.first.parse_next(node, state, context); - if r.is_ok() { - r - } else { - self.second.parse_next(node, state, context) + // Only fall back when the first parser didn't match the node. + // Attribute/validation errors should propagate instead of being masked. + match self.first.parse_next(node, state, context) { + Ok(result) => Ok(result), + Err(ElementError::Unexpected { .. 
} | ElementError::UnexpectedEnd) => { + self.second.parse_next(node, state, context) + } + Err(e) => Err(e), } } } diff --git a/xee-xslt-ast/src/context.rs b/xee-xslt-ast/src/context.rs index fdbcb0655..483f2c9a1 100644 --- a/xee-xslt-ast/src/context.rs +++ b/xee-xslt-ast/src/context.rs @@ -1,6 +1,6 @@ use std::str::FromStr; -use ahash::{HashMap, HashMapExt, HashSet, HashSetExt}; +use ahash::{HashSet, HashSetExt}; use rust_decimal::Decimal; use xee_xpath_ast::{ast as xpath_ast, VariableNames, XPathParserContext}; use xee_xpath_ast::{Namespaces, FN_NAMESPACE}; @@ -171,7 +171,7 @@ impl Context { } pub(crate) fn namespaces<'a>(&'a self, state: &'a State) -> Namespaces { - let mut namespaces = HashMap::new(); + let mut namespaces = Namespaces::default_namespaces(); for (prefix, ns) in &self.prefixes { let prefix = state.xot.prefix_str(*prefix); let uri = state.xot.namespace_str(*ns); diff --git a/xee-xslt-ast/src/instruction.rs b/xee-xslt-ast/src/instruction.rs index 7db34b7c9..4a1c12cc6 100644 --- a/xee-xslt-ast/src/instruction.rs +++ b/xee-xslt-ast/src/instruction.rs @@ -1,6 +1,7 @@ use std::sync::OnceLock; use xee_name::{Namespaces, VariableNames}; use xot::Node; +use xot::xmlname::NameStrInfo; use xee_xpath_ast::ast as xpath_ast; @@ -124,10 +125,14 @@ impl InstructionParser for ast::Declaration { impl InstructionParser for ast::ElementNode { fn parse(content: &Content, attributes: &Attributes) -> Result { let mut element_attributes = Vec::new(); + let mut type_ = None; for key in content.state.xot.attributes(content.node).keys() { let name = content.state.xot.name_ref(key, content.node)?; // if any name is in the xsl namespace, we skip it if name.namespace_id() == content.state.names.xsl_ns { + if name.local_name() == "type" { + type_ = attributes.optional(key, attributes.eqname())?; + } continue; } let value = attributes.required(key, attributes.value_template(attributes.string()))?; @@ -143,6 +148,7 @@ impl InstructionParser for ast::ElementNode { attributes: 
element_attributes, span: content.span()?, sequence_constructor: content.sequence_constructor()?, + type_, }) } } @@ -1401,6 +1407,17 @@ impl InstructionParser for ast::Override { impl InstructionParser for ast::Param { fn parse(content: &Content, attributes: &Attributes) -> Result { let names = &content.state.names; + if attributes + .content + .xot_attributes() + .contains_key(names.visibility) + { + return Err(attributes + .content + .state + .attribute_unexpected(attributes.content.node, names.visibility, "unexpected attribute") + .into()); + } Ok(ast::Param { name: attributes.required(names.name, attributes.eqname())?, select: attributes.optional(names.select, attributes.xpath())?, @@ -1639,7 +1656,69 @@ impl InstructionParser for ast::Transform { } } -// TODO: xsl:try +impl InstructionParser for ast::Try { + fn parse(content: &Content, attributes: &Attributes) -> Result { + let names = &content.state.names; + let select = attributes.optional(names.select, attributes.xpath())?; + let rollback_output = attributes.optional(names.rollback_output, attributes.boolean())?; + let span = content.span()?; + + let mut sequence_items = Vec::new(); + let mut catch = None; + let mut catches = Vec::new(); + + let parse_sequence_constructor = sequence_constructor(); + let mut next = content.state.xot.first_child(content.node); + + while let Some(node) = next { + if let Some(element) = content.state.xot.element(node) { + let name = element.name(); + if name == names.xsl_catch { + let parsed = content.with_node(node).parse_element(element, |attributes| { + ast::Catch::parse_and_validate(attributes) + })?; + if catch.is_none() { + catch = Some(parsed); + } else { + catches.push(ast::TryCatchOrFallback::Catch(parsed)); + } + next = content.state.next(node); + continue; + } + if name == names.xsl_fallback { + let parsed = content.with_node(node).parse_element(element, |attributes| { + ast::Fallback::parse_and_validate(attributes) + })?; + 
catches.push(ast::TryCatchOrFallback::Fallback(parsed)); + next = content.state.next(node); + continue; + } + } + + if catch.is_some() { + return Err(Error::Unexpected { + span: content.state.span(node).ok_or(Error::Internal)?, + }); + } + + let (items, new_next) = + parse_sequence_constructor.parse_next(Some(node), content.state, &content.context)?; + sequence_items.extend(items); + next = new_next; + } + + let catch = catch.ok_or(Error::Unexpected { span })?; + + Ok(ast::Try { + select, + rollback_output, + span, + sequence_constructor: sequence_items, + catch, + catches, + }) + } +} // TODO: xsl:use-package @@ -1728,6 +1807,17 @@ impl InstructionParser for ast::WherePopulated { impl InstructionParser for ast::WithParam { fn parse(content: &Content, attributes: &Attributes) -> Result { let names = &content.state.names; + if attributes + .content + .xot_attributes() + .contains_key(names.required) + { + return Err(attributes + .content + .state + .attribute_unexpected(attributes.content.node, names.required, "unexpected attribute") + .into()); + } Ok(ast::WithParam { name: attributes.required(names.name, attributes.eqname())?, select: attributes.optional(names.select, attributes.xpath())?, diff --git a/xee-xslt-ast/src/names.rs b/xee-xslt-ast/src/names.rs index 513e6abd6..1c9191c42 100644 --- a/xee-xslt-ast/src/names.rs +++ b/xee-xslt-ast/src/names.rs @@ -108,6 +108,7 @@ impl SequenceConstructorName { ast::SourceDocument::parse_sequence_constructor_item(attributes) } SequenceConstructorName::Text => ast::Text::parse_sequence_constructor_item(attributes), + SequenceConstructorName::Try => ast::Try::parse_sequence_constructor_item(attributes), SequenceConstructorName::ValueOf => { ast::ValueOf::parse_sequence_constructor_item(attributes) } @@ -142,6 +143,12 @@ impl DeclarationName { DeclarationName::Accumulator => ast::Accumulator::parse_declaration(attributes), DeclarationName::Template => ast::Template::parse_declaration(attributes), DeclarationName::Output => 
ast::Output::parse_declaration(attributes), + DeclarationName::Function => ast::Function::parse_declaration(attributes), + DeclarationName::Mode => ast::Mode::parse_declaration(attributes), + DeclarationName::Param => ast::Param::parse_declaration(attributes), + DeclarationName::Variable => ast::Variable::parse_declaration(attributes), + DeclarationName::Import => ast::Import::parse_declaration(attributes), + DeclarationName::Include => ast::Include::parse_declaration(attributes), _ => Err(ElementError::Unsupported(format!( "Unsupported declaration: {:?}", &self @@ -200,6 +207,7 @@ pub(crate) struct Names { // XSL elements pub(crate) xsl_accumulator_rule: xot::NameId, pub(crate) xsl_attribute: xot::NameId, + pub(crate) xsl_catch: xot::NameId, pub(crate) xsl_fallback: xot::NameId, pub(crate) xsl_for_each: xot::NameId, pub(crate) xsl_for_each_group: xot::NameId, @@ -306,6 +314,7 @@ pub(crate) struct Names { pub(crate) regex: xot::NameId, pub(crate) required: xot::NameId, pub(crate) result_prefix: xot::NameId, + pub(crate) rollback_output: xot::NameId, pub(crate) schema_aware: xot::NameId, pub(crate) schema_location: xot::NameId, pub(crate) select: xot::NameId, @@ -399,6 +408,7 @@ impl Names { xot.add_name_ns("apply-templates", xsl_ns), xot.add_name_ns("attribute-set", xsl_ns), xot.add_name_ns("call-template", xsl_ns), + xot.add_name_ns("catch", xsl_ns), xot.add_name_ns("character-map", xsl_ns), xot.add_name_ns("choose", xsl_ns), xot.add_name_ns("evaluate", xsl_ns), @@ -411,6 +421,7 @@ impl Names { xot.add_name_ns("override", xsl_ns), xot.add_name_ns("package", xsl_ns), xot.add_name_ns("stylesheet", xsl_ns), + xot.add_name_ns("try", xsl_ns), xot.add_name_ns("transform", xsl_ns), xot.add_name_ns("use-package", xsl_ns), ] @@ -443,6 +454,7 @@ impl Names { xsl_accumulator_rule: xot.add_name_ns("accumulator-rule", xsl_ns), xsl_attribute: xot.add_name_ns("attribute", xsl_ns), + xsl_catch: xot.add_name_ns("catch", xsl_ns), xsl_fallback: xot.add_name_ns("fallback", xsl_ns), 
xsl_for_each: xot.add_name_ns("for-each", xsl_ns), xsl_for_each_group: xot.add_name_ns("for-each-group", xsl_ns), @@ -548,6 +560,7 @@ impl Names { regex: xot.add_name("regex"), required: xot.add_name("required"), result_prefix: xot.add_name("result-prefix"), + rollback_output: xot.add_name("rollback-output"), schema_aware: xot.add_name("schema-aware"), schema_location: xot.add_name("schema-location"), select: xot.add_name("select"), diff --git a/xee-xslt-ast/src/parse.rs b/xee-xslt-ast/src/parse.rs index 7bd0d9168..5a4eba641 100644 --- a/xee-xslt-ast/src/parse.rs +++ b/xee-xslt-ast/src/parse.rs @@ -19,8 +19,7 @@ pub fn parse_transform(s: &str) -> Result { let mut state = State::new(xot, span_info, names); let mut xot = Xot::new(); - static_evaluate(&mut state, node, Variables::new(), &mut xot) - .map_err(|_e| Error::Unsupported(format!("Static evaluate error: {:?}", _e)))?; + static_evaluate(&mut state, node, Variables::new(), &mut xot)?; let parser = XsltParser::new(&state); parser.parse_transform(node) } diff --git a/xee-xslt-ast/tests/snapshots/snapshot_tests__attributes_on_literal_element.snap b/xee-xslt-ast/tests/snapshots/snapshot_tests__attributes_on_literal_element.snap index 9f69d0dce..37951d65e 100644 --- a/xee-xslt-ast/tests/snapshots/snapshot_tests__attributes_on_literal_element.snap +++ b/xee-xslt-ast/tests/snapshots/snapshot_tests__attributes_on_literal_element.snap @@ -1,5 +1,6 @@ --- -source: xee-xslt-ast/src/instruction.rs +source: xee-xslt-ast/tests/snapshot_tests.rs +assertion_line: 258 expression: "parse_sequence_constructor_item(r#\"

\"#)" --- Ok(Content(Element(ElementNode( @@ -27,6 +28,7 @@ Ok(Content(Element(ElementNode( )), ], sequence_constructor: [], + type_: None, span: Span( start: 1, end: 2, diff --git a/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element.snap b/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element.snap index 32bb5ac98..b185895db 100644 --- a/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element.snap +++ b/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element.snap @@ -1,5 +1,6 @@ --- -source: xee-xslt-ast/src/instruction.rs +source: xee-xslt-ast/tests/snapshot_tests.rs +assertion_line: 98 expression: "parse_sequence_constructor_item(r#\"\"#)" --- Ok(Content(Element(ElementNode( @@ -10,6 +11,7 @@ Ok(Content(Element(ElementNode( ), attributes: [], sequence_constructor: [], + type_: None, span: Span( start: 1, end: 4, diff --git a/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element_with_standard_attribute.snap b/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element_with_standard_attribute.snap index 69299e37a..77eb82ced 100644 --- a/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element_with_standard_attribute.snap +++ b/xee-xslt-ast/tests/snapshots/snapshot_tests__literal_result_element_with_standard_attribute.snap @@ -1,5 +1,6 @@ --- -source: xee-xslt-ast/src/instruction.rs +source: xee-xslt-ast/tests/snapshot_tests.rs +assertion_line: 103 expression: "parse_sequence_constructor_item(r#\"\"#)" --- Ok(Content(Element(ElementNode( @@ -10,6 +11,7 @@ Ok(Content(Element(ElementNode( ), attributes: [], sequence_constructor: [], + type_: None, span: Span( start: 1, end: 4, diff --git a/xee-xslt-ast/tests/snapshots/snapshot_tests__nested_literal_elements.snap b/xee-xslt-ast/tests/snapshots/snapshot_tests__nested_literal_elements.snap index 1d5aa5e19..d2112c270 100644 --- a/xee-xslt-ast/tests/snapshots/snapshot_tests__nested_literal_elements.snap +++ 
b/xee-xslt-ast/tests/snapshots/snapshot_tests__nested_literal_elements.snap @@ -1,5 +1,6 @@ --- -source: xee-xslt-ast/src/instruction.rs +source: xee-xslt-ast/tests/snapshot_tests.rs +assertion_line: 244 expression: "parse_sequence_constructor_item(r#\"

\"#)" --- Ok(Instruction(If(If( @@ -40,12 +41,14 @@ Ok(Instruction(If(If( ), attributes: [], sequence_constructor: [], + type_: None, span: Span( start: 75, end: 82, ), ))), ], + type_: None, span: Span( start: 72, end: 73, diff --git a/xee-xslt-ast/tests/snapshots/snapshot_tests__no_fn_namespace_by_default.snap b/xee-xslt-ast/tests/snapshots/snapshot_tests__no_fn_namespace_by_default.snap index 1d42cb76e..edd6751d9 100644 --- a/xee-xslt-ast/tests/snapshots/snapshot_tests__no_fn_namespace_by_default.snap +++ b/xee-xslt-ast/tests/snapshots/snapshot_tests__no_fn_namespace_by_default.snap @@ -1,12 +1,34 @@ --- -source: xee-xslt-ast/src/instruction.rs +source: xee-xslt-ast/tests/snapshot_tests.rs +assertion_line: 110 expression: "parse_sequence_constructor_item(r#\"Hello\"#)" --- -Err(Attribute(XPathParser(UnknownPrefix( - span: SimpleSpan( - start: 63, - end: 70, - context: (), +Ok(Instruction(If(If( + test: Expression( + xpath: XPath(Expr([ + Path(PathExpr( + steps: [ + PrimaryExpr(FunctionCall(FunctionCall( + name: OwnedName( + local_name_str: "true", + namespace_str: "http://www.w3.org/2005/xpath-functions", + prefix_str: "fn", + ), + arguments: [], + ))), + ], + )), + ])), + span: Span( + start: 63, + end: 72, + ), + ), + sequence_constructor: [ + Content(Text("Hello")), + ], + span: Span( + start: 1, + end: 7, ), - prefix: "fn", )))) diff --git a/xee-xslt-ast/tests/snapshots/snapshot_tests__sequence_constructor_nested_in_literal_element.snap b/xee-xslt-ast/tests/snapshots/snapshot_tests__sequence_constructor_nested_in_literal_element.snap index ca1318c41..ec4c3b008 100644 --- a/xee-xslt-ast/tests/snapshots/snapshot_tests__sequence_constructor_nested_in_literal_element.snap +++ b/xee-xslt-ast/tests/snapshots/snapshot_tests__sequence_constructor_nested_in_literal_element.snap @@ -1,5 +1,6 @@ --- -source: xee-xslt-ast/src/instruction.rs +source: xee-xslt-ast/tests/snapshot_tests.rs +assertion_line: 251 expression: "parse_sequence_constructor_item(r#\"

foo

\"#)" --- Ok(Instruction(If(If( @@ -62,6 +63,7 @@ Ok(Instruction(If(If( ), ))), ], + type_: None, span: Span( start: 72, end: 73, diff --git a/xee-xslt-compiler/src/ast_ir.rs b/xee-xslt-compiler/src/ast_ir.rs index d280708ae..c8723b46a 100644 --- a/xee-xslt-compiler/src/ast_ir.rs +++ b/xee-xslt-compiler/src/ast_ir.rs @@ -1,10 +1,14 @@ -use ahash::HashSetExt; +use ahash::{HashMap, HashMapExt, HashSet, HashSetExt}; +use std::path::{Path, PathBuf}; use xee_name::{Name, Namespaces, FN_NAMESPACE}; -use xee_interpreter::{context::StaticContext, error, interpreter, sequence::QNameOrString}; +use xee_interpreter::{context::StaticContext, declaration::CatchError, error, interpreter, sequence::QNameOrString}; use xee_ir::{compile_xslt, ir, Bindings, Variables}; -use xee_xpath_ast::{ast as xpath_ast, pattern::transform_pattern, span::Spanned}; +use xee_xpath_ast::{ast as xpath_ast, parse_name, pattern::transform_pattern, span::Spanned}; +use xee_schema_type::Xs; +use xee_xpath_compiler::UserFunctions; use xee_xslt_ast::{ast, parse_transform}; +use xee_xslt_ast::error::{AttributeError, ElementError}; use xot::xmlname::NameStrInfo; use crate::{default_declarations::text_only_copy_declarations, priority::default_priority}; @@ -12,34 +16,206 @@ use crate::{default_declarations::text_only_copy_declarations, priority::default struct IrConverter<'a> { variables: Variables, static_context: &'a StaticContext, + global_params: Vec, + global_param_lookup: HashMap, + function_lookup: HashMap<(Name, u8), usize>, + user_functions: Option, } +#[derive(Debug, Clone)] +struct DeclarationWithImport { + declaration: ast::Declaration, + import_level: u32, + is_builtin: bool, +} + +#[allow(dead_code)] pub fn compile( transform: ast::Transform, static_context: StaticContext, ) -> error::SpannedResult { - let mut ir_converter = IrConverter::new(&static_context); - let declarations = ir_converter.transform(&transform)?; - compile_xslt(declarations, static_context) + let declarations = transform + 
.declarations + .into_iter() + .map(|declaration| DeclarationWithImport { + declaration, + import_level: 0, + is_builtin: false, + }) + .collect::>(); + compile_with_imports(declarations, static_context) } pub fn parse( static_context: StaticContext, xslt: &str, +) -> error::SpannedResult { + parse_with_base(static_context, xslt, None) +} + +pub fn parse_with_base( + static_context: StaticContext, + xslt: &str, + base_path: Option<&Path>, ) -> error::SpannedResult { let transform = parse_transform(xslt); // TODO: better error handling - let mut transform = match transform { + let transform = match transform { Ok(transform) => transform, - Err(_e) => { - return Err(error::Error::Unsupported(format!("Failed parsing XSLT: {:?}", _e)).into()); + Err(err) => { + return Err(map_parse_error(err).into()); + } + }; + let mut declarations = if let Some(base_path) = base_path { + let base_dir = base_path.parent().unwrap_or_else(|| Path::new(".")); + let mut visited = HashSet::new(); + resolve_imports(transform, base_dir, &mut visited, 0)? 
+ } else { + transform + .declarations + .into_iter() + .map(|declaration| DeclarationWithImport { + declaration, + import_level: 0, + is_builtin: false, + }) + .collect() + }; + let max_import_level = declarations + .iter() + .map(|decl| decl.import_level) + .max() + .unwrap_or(0); + let default_import_level = max_import_level.saturating_add(1); + let mut default_declarations = text_only_copy_declarations() + .unwrap() + .into_iter() + .map(|declaration| DeclarationWithImport { + declaration, + import_level: default_import_level, + is_builtin: true, + }) + .collect::>(); + default_declarations.append(&mut declarations); + compile_with_imports(default_declarations, static_context) +} + +fn compile_with_imports( + declarations: Vec, + static_context: StaticContext, +) -> error::SpannedResult { + let mut ir_converter = IrConverter::new(&static_context); + let declarations = ir_converter.transform_with_imports(&declarations)?; + compile_xslt(declarations, static_context) +} + +fn resolve_imports( + transform: ast::Transform, + base_dir: &Path, + visited: &mut HashSet, + import_level: u32, +) -> error::SpannedResult> { + let mut import_decls = Vec::new(); + let mut local_decls = Vec::new(); + + for decl in transform.declarations { + match decl { + ast::Declaration::Import(import) => { + let import_path = resolve_import_path(&import.href, base_dir)?; + let canonical = + std::fs::canonicalize(&import_path).unwrap_or_else(|_| import_path.clone()); + if !visited.insert(canonical) { + return Err(error::Error::Unsupported( + "Circular xsl:import detected".to_string(), + ) + .into()); + } + let xslt = std::fs::read_to_string(&import_path).map_err(|e| { + error::Error::Unsupported(format!( + "Failed to read xsl:import href '{}': {}", + import.href, e + )) + })?; + let import_transform = parse_transform(&xslt).map_err(map_parse_error)?; + let import_base_dir = import_path.parent().unwrap_or(base_dir); + let import_transform = + resolve_imports(import_transform, import_base_dir, 
visited, import_level + 1)?; + import_decls.extend(import_transform); + } + ast::Declaration::Include(include) => { + let include_path = resolve_import_path(&include.href, base_dir)?; + let canonical = + std::fs::canonicalize(&include_path).unwrap_or_else(|_| include_path.clone()); + if !visited.insert(canonical) { + return Err(error::Error::Unsupported( + "Circular xsl:include detected".to_string(), + ) + .into()); + } + let xslt = std::fs::read_to_string(&include_path).map_err(|e| { + error::Error::Unsupported(format!( + "Failed to read xsl:include href '{}': {}", + include.href, e + )) + })?; + let include_transform = parse_transform(&xslt).map_err(map_parse_error)?; + let include_base_dir = include_path.parent().unwrap_or(base_dir); + let include_transform = resolve_imports( + include_transform, + include_base_dir, + visited, + import_level, + )?; + local_decls.extend(include_transform); + } + _ => local_decls.push(DeclarationWithImport { + declaration: decl, + import_level, + is_builtin: false, + }), } + } + + import_decls.extend(local_decls); + Ok(import_decls) +} + +fn resolve_import_path(href: &str, base_dir: &Path) -> error::SpannedResult { + let href = href.strip_prefix("file://").unwrap_or(href); + let href_path = Path::new(href); + let path = if href_path.is_absolute() { + href_path.to_path_buf() + } else { + base_dir.join(href_path) }; - // insert default rules early on in precedence order - let mut declarations = text_only_copy_declarations().unwrap(); - declarations.extend(transform.declarations); - transform.declarations = declarations; - compile(transform, static_context) + Ok(path) +} + +fn map_parse_error(err: ElementError) -> error::Error { + match err { + ElementError::Attribute(attr) => map_attribute_error(attr), + ElementError::Unexpected { .. 
} | ElementError::UnexpectedEnd => error::Error::XTSE0010, + ElementError::ValueTemplate(_) => error::Error::XTSE0020, + ElementError::XPathRunTime(spanned) => spanned.error, + ElementError::Unsupported(reason) => error::Error::Unsupported(reason), + ElementError::Internal => error::Error::Unsupported(String::from("Internal XSLT error")), + } +} + +fn map_attribute_error(err: AttributeError) -> error::Error { + match err { + AttributeError::NotFound { .. } => error::Error::XTSE0010, + AttributeError::Unexpected { .. } => error::Error::XTSE0090, + AttributeError::Invalid { .. } | AttributeError::InvalidEqName { .. } => { + error::Error::XTSE0020 + } + AttributeError::XPathParser(err) => { + eprintln!("XPath parse error: {err:?}"); + error::Error::XPST0003 + } + AttributeError::ValueTemplate(_) => error::Error::XPST0003, + AttributeError::Internal => error::Error::Unsupported(String::from("Internal XSLT error")), + } } impl<'a> IrConverter<'a> { @@ -47,6 +223,10 @@ impl<'a> IrConverter<'a> { IrConverter { variables: Variables::new(), static_context, + global_params: Vec::new(), + global_param_lookup: HashMap::new(), + function_lookup: HashMap::new(), + user_functions: None, } } @@ -60,7 +240,7 @@ impl<'a> IrConverter<'a> { xpath: xee_xpath_ast::ast::XPath::parse( "/", &Namespaces::default(), - &xee_name::VariableNames::new(), + &xee_name::VariableNames::default(), ) .unwrap(), span: xee_xslt_ast::ast::Span::new(0, 0), @@ -106,27 +286,44 @@ impl<'a> IrConverter<'a> { }) } - fn transform(&mut self, transform: &ast::Transform) -> error::SpannedResult { + fn transform_with_imports( + &mut self, + declarations: &[DeclarationWithImport], + ) -> error::SpannedResult { + self.collect_global_params(declarations)?; + self.collect_functions(declarations)?; let main_sequence_constructor = self.main_sequence_constructor(); - let main = self.sequence_constructor_function(&main_sequence_constructor)?; - let mut declarations = ir::Declarations::new(main); - - for declaration in 
&transform.declarations { - self.declaration(&mut declarations, declaration)?; + let main = self.sequence_constructor_function(&main_sequence_constructor, &[])?; + let mut ir_declarations = ir::Declarations::new(main); + + for declaration in declarations { + self.declaration_with_import( + &mut ir_declarations, + &declaration.declaration, + declaration.import_level, + declaration.is_builtin, + )?; } - Ok(declarations) + ir_declarations.global_params = self.global_params.clone(); + Ok(ir_declarations) } - fn declaration( + fn declaration_with_import( &mut self, declarations: &mut ir::Declarations, declaration: &ast::Declaration, + import_level: u32, + is_builtin: bool, ) -> error::SpannedResult<()> { use ast::Declaration::*; match declaration { - Template(template) => self.template(declarations, template), + Template(template) => self.template(declarations, template, import_level, is_builtin), Mode(mode) => self.mode(declarations, mode), Output(output) => self.output(declarations, output), + Param(param) => self.param_declaration(param), + Variable(variable) => self.variable_declaration(variable), + Function(function) => self.function_declaration(declarations, function), + Import(_) | Include(_) => Ok(()), _ => Err(error::Error::Unsupported(format!( "Declaration not supported: {:?}", declaration @@ -139,41 +336,149 @@ impl<'a> IrConverter<'a> { &mut self, declarations: &mut ir::Declarations, template: &ast::Template, + import_level: u32, + is_builtin: bool, ) -> error::SpannedResult<()> { + if template.match_.is_none() && template.name.is_none() { + return Err( + error::Error::Unsupported("Template without match or name".to_string()).into(), + ); + } + + let context_names = self.variables.push_context(); + let (template_params, template_ir_params) = self.template_params(&template.params)?; + let function_definition = self.sequence_constructor_function_with_context( + &context_names, + &template.sequence_constructor, + &template_ir_params, + )?; + 
self.variables.pop_context(); + if let Some(pattern) = &template.match_ { - let priority = if let Some(priority) = &template.priority { - *priority + let priorities = if let Some(priority) = &template.priority { + vec![(pattern.pattern.clone(), *priority)] } else { - let default_priorities = default_priority(&pattern.pattern).collect::<Vec<_>>(); - if default_priorities.len() > 1 { - // for now, we can't deal with multiple registration yet - return Err(error::Error::Unsupported( - "Default priorities splitting not supported".to_string(), - ) - .into()); - } else { - default_priorities.first().unwrap().1 - } + default_priority(&pattern.pattern) + .map(|(p, d)| (p.into_owned(), d)) + .collect() }; - let function_definition = - self.sequence_constructor_function(&template.sequence_constructor)?; let modes = template .mode .iter() .map(Self::ast_mode_value_to_ir_mode_value) - .collect(); + .collect::<Vec<_>>(); + + for (pattern, priority) in priorities { + declarations.rules.push(ir::Rule { + priority, + modes: modes.clone(), + import_level, + is_builtin, + pattern: transform_pattern(&pattern, |expr| self.pattern_predicate(expr))?, + function_definition: function_definition.clone(), + template_params: template_params.clone(), + }); + } + } - declarations.rules.push(ir::Rule { - priority, - modes, - pattern: transform_pattern(&pattern.pattern, |expr| self.pattern_predicate(expr))?, + if let Some(name) = &template.name { + declarations.named_templates.push(ir::NamedTemplate { + name: name.clone(), function_definition, + template_params, }); - Ok(()) - } else { - Err(error::Error::Unsupported("Named templates not supported".to_string()).into()) } + + Ok(()) + } + + fn sequence_constructor_function_with_context( + &mut self, + context_names: &ir::ContextNames, + sequence_constructor: &ast::SequenceConstructor, + extra_params: &[ir::Param], + ) -> error::SpannedResult { + let bindings = self.sequence_constructor(sequence_constructor)?; + let mut params = vec![ + ir::Param { + name:
context_names.item.clone(), + type_: None, + }, + ir::Param { + name: context_names.position.clone(), + type_: None, + }, + ir::Param { + name: context_names.last.clone(), + type_: None, + }, + ]; + params.extend(self.global_param_ir_params()); + params.extend(extra_params.iter().cloned()); + Ok(ir::FunctionDefinition { + params, + return_type: None, + body: Box::new(bindings.expr()), + }) + } + + fn function_declaration( + &mut self, + declarations: &mut ir::Declarations, + function: &ast::Function, + ) -> error::SpannedResult<()> { + if function.override_ + || function.override_extension_function + || function.new_each_time.is_some() + || function.cache + { + return Err(error::Error::Unsupported( + "Overridable or cached functions are not supported".to_string(), + ) + .into()); + } + if function.visibility.is_some() { + return Err(error::Error::Unsupported( + "Function visibility is not supported".to_string(), + ) + .into()); + } + if function.streamability.is_some() { + return Err(error::Error::Unsupported( + "Streamable functions are not supported".to_string(), + ) + .into()); + } + + let function_params = self.function_params(&function.params)?; + self.variables.push_absent_context(); + let bindings = self.sequence_constructor(&function.sequence_constructor)?; + self.variables.pop_context(); + + let mut params = function_params; + params.extend(self.global_param_ir_params()); + + let function_definition = ir::FunctionDefinition { + params, + return_type: function.as_.clone(), + body: Box::new(bindings.expr()), + }; + + let arity = function.params.len(); + if arity > u8::MAX as usize { + return Err(error::Error::Unsupported( + "Function arity too large".to_string(), + ) + .into()); + } + + declarations.functions.push(ir::FunctionBinding { + name: function.name.clone(), + arity: arity as u8, + main: function_definition, + }); + Ok(()) } fn mode( @@ -181,7 +486,17 @@ impl<'a> IrConverter<'a> { declarations: &mut ir::Declarations, mode: &ast::Mode, ) -> 
error::SpannedResult<()> { - declarations.modes.insert(mode.name.clone(), ir::Mode {}); + let on_no_match = mode.on_no_match.as_ref().map(|m| match m { + ast::OnNoMatch::DeepCopy => ir::OnNoMatch::DeepCopy, + ast::OnNoMatch::ShallowCopy => ir::OnNoMatch::ShallowCopy, + ast::OnNoMatch::DeepSkip => ir::OnNoMatch::DeepSkip, + ast::OnNoMatch::ShallowSkip => ir::OnNoMatch::ShallowSkip, + ast::OnNoMatch::TextOnlyCopy => ir::OnNoMatch::TextOnlyCopy, + ast::OnNoMatch::Fail => ir::OnNoMatch::Fail, + }); + declarations + .modes + .insert(mode.name.clone(), ir::Mode { on_no_match }); Ok(()) } @@ -309,11 +624,12 @@ impl<'a> IrConverter<'a> { fn sequence_constructor_function( &mut self, sequence_constructor: &ast::SequenceConstructor, + extra_params: &[ir::Param], ) -> error::SpannedResult { let context_names = self.variables.push_context(); let bindings = self.sequence_constructor(sequence_constructor)?; self.variables.pop_context(); - let params = vec![ + let mut params = vec![ ir::Param { name: context_names.item, type_: None, @@ -327,6 +643,8 @@ impl<'a> IrConverter<'a> { type_: None, }, ]; + params.extend(self.global_param_ir_params()); + params.extend(extra_params.iter().cloned()); Ok(ir::FunctionDefinition { params, return_type: None, @@ -338,27 +656,46 @@ impl<'a> IrConverter<'a> { &mut self, sequence_constructor: &[ast::SequenceConstructorItem], ) -> error::SpannedResult { - let mut items = sequence_constructor.iter(); - let left = items.next(); - if let Some(left) = left { - if let Some((name, var_bindings)) = self.variable(left)? { + if sequence_constructor.is_empty() { + let empty_sequence = self.empty_sequence(); + return Ok(Bindings::new( + self.variables + .new_binding(empty_sequence.value, empty_sequence.span), + )); + } + + for (index, item) in sequence_constructor.iter().enumerate() { + if let Some((name, var_bindings)) = self.variable(item)? 
{ + let rest_bindings = self.sequence_constructor(&sequence_constructor[index + 1..])?; let expr = ir::Expr::Let(ir::Let { name, var_expr: Box::new(var_bindings.expr()), - return_expr: Box::new(self.sequence_constructor(items.as_slice())?.expr()), + return_expr: Box::new(rest_bindings.expr()), + }); + let let_bindings = + Bindings::new(self.variables.new_binding(expr, (0..0).into())); + if index == 0 { + return Ok(let_bindings); + } + let prefix_bindings = self.sequence_constructor_concat( + &sequence_constructor[0], + sequence_constructor[1..index].iter(), + )?; + let (left_atom, left_bindings) = prefix_bindings.atom_bindings(); + let (right_atom, right_bindings) = let_bindings.atom_bindings(); + let expr = ir::Expr::Binary(ir::Binary { + left: left_atom, + op: ir::BinaryOperator::Comma, + right: right_atom, }); - return Ok(Bindings::new( - self.variables.new_binding(expr, (0..0).into()), - )); + let binding = self.variables.new_binding_no_span(expr); + return Ok(left_bindings.concat(right_bindings).bind(binding)); } - self.sequence_constructor_concat(left, items) - } else { - let empty_sequence = self.empty_sequence(); - Ok(Bindings::new( - self.variables - .new_binding(empty_sequence.value, empty_sequence.span), - )) } + + let mut items = sequence_constructor.iter(); + let left = items.next().expect("sequence_constructor not empty"); + self.sequence_constructor_concat(left, items) } fn sequence_constructor_concat<'b>( @@ -401,9 +738,13 @@ impl<'a> IrConverter<'a> { use ast::SequenceConstructorInstruction::*; match instruction { ApplyTemplates(apply_templates) => self.apply_templates(apply_templates), + ApplyImports(apply_imports) => self.apply_imports(apply_imports), + NextMatch(next_match) => self.next_match(next_match), + CallTemplate(call_template) => self.call_template(call_template), ValueOf(value_of) => self.value_of(value_of), If(if_) => self.if_(if_), Choose(choose) => self.choose(choose), + Assert(assert_) => self.assert_(assert_), ForEach(for_each) => 
self.for_each(for_each), Iterate(iterate) => self.iterate(iterate), NextIteration(next_iteration) => self.next_iteration(next_iteration), @@ -411,8 +752,10 @@ impl<'a> IrConverter<'a> { Copy(copy) => self.copy(copy), CopyOf(copy_of) => self.copy_of(copy_of), Sequence(sequence) => self.sequence(sequence), + Document(document) => self.document(document), Element(element) => self.element(element), Text(text) => self.text(text), + Try(try_) => self.try_(try_), Attribute(attribute) => self.attribute(attribute), Namespace(namespace) => self.namespace(namespace), Comment(comment) => self.comment(comment), @@ -474,6 +817,16 @@ impl<'a> IrConverter<'a> { let (element_atom, mut bindings) = bindings .bind_expr_no_span(&mut self.variables, name_expr) .atom_bindings(); + if let Some(type_name) = &element_node.type_ { + let xs = self.xs_type_from_eqname(type_name, element_node.span)?; + let set_type_expr = ir::Expr::XmlSetType(ir::XmlSetType { + node: element_atom.clone(), + xs, + }); + let set_type_bindings = + bindings.bind_expr_no_span(&mut self.variables, set_type_expr); + bindings = bindings.concat(set_type_bindings); + } for (name, value) in &element_node.attributes { let (value_atom, value_bindings) = self.attribute_value_template(value)?.atom_bindings(); @@ -502,6 +855,18 @@ impl<'a> IrConverter<'a> { Ok(bindings) } + fn xs_type_from_eqname( + &self, + type_name: &ast::EqName, + span: ast::Span, + ) -> error::SpannedResult { + Xs::by_name(type_name.namespace(), type_name.local_name()).ok_or_else(|| { + let span = xpath_ast::Span::new(span.start, span.end); + error::Error::Unsupported("xsl:type only supports xs:* names".to_string()) + .with_ast_span(span) + }) + } + fn sequence_constructor_append( &mut self, element_atom: ir::AtomS, @@ -542,15 +907,287 @@ impl<'a> IrConverter<'a> { ast::ApplyTemplatesModeValue::Current => ir::ApplyTemplatesModeValue::Current, }; + let mut params = Vec::new(); + let mut bindings = bindings; + for content in &apply_templates.content { + let 
with_param = match content { + ast::ApplyTemplatesContent::WithParam(with_param) => with_param, + ast::ApplyTemplatesContent::Sort(_) => { + continue; + } + }; + if with_param.tunnel { + return Err(error::Error::Unsupported( + "Tunnel params are not supported".to_string(), + ) + .into()); + } + let (value_atom, value_bindings) = + self.with_param_value_atom(with_param)?.atom_bindings(); + bindings = bindings.concat(value_bindings); + params.push(ir::WithParam { + name: with_param.name.clone(), + value: value_atom, + }); + } + Ok(bindings.bind_expr_no_span( &mut self.variables, ir::Expr::ApplyTemplates(ir::ApplyTemplates { mode, select: select_atom, + params, + }), + )) + } + + fn apply_imports( + &mut self, + apply_imports: &ast::ApplyImports, + ) -> error::SpannedResult { + let mut params = Vec::new(); + let mut bindings = Bindings::empty(); + for with_param in &apply_imports.with_params { + if with_param.tunnel { + return Err(error::Error::Unsupported( + "Tunnel params are not supported".to_string(), + ) + .into()); + } + let (value_atom, value_bindings) = + self.with_param_value_atom(with_param)?.atom_bindings(); + bindings = bindings.concat(value_bindings); + params.push(ir::WithParam { + name: with_param.name.clone(), + value: value_atom, + }); + } + + Ok(bindings.bind_expr_no_span( + &mut self.variables, + ir::Expr::ApplyImports(ir::ApplyImports { params }), + )) + } + + fn next_match(&mut self, next_match: &ast::NextMatch) -> error::SpannedResult { + let mut params = Vec::new(); + let mut bindings = Bindings::empty(); + for content in &next_match.content { + let with_param = match content { + ast::NextMatchContent::WithParam(with_param) => with_param, + ast::NextMatchContent::Fallback(_) => { + return Err(error::Error::Unsupported( + "xsl:fallback not supported".to_string(), + ) + .into()); + } + }; + if with_param.tunnel { + return Err(error::Error::Unsupported( + "Tunnel params are not supported".to_string(), + ) + .into()); + } + let (value_atom, 
value_bindings) = + self.with_param_value_atom(with_param)?.atom_bindings(); + bindings = bindings.concat(value_bindings); + params.push(ir::WithParam { + name: with_param.name.clone(), + value: value_atom, + }); + } + + Ok(bindings.bind_expr_no_span( + &mut self.variables, + ir::Expr::NextMatch(ir::NextMatch { params }), + )) + } + + fn call_template( + &mut self, + call_template: &ast::CallTemplate, + ) -> error::SpannedResult { + let mut params = Vec::new(); + let mut bindings = Bindings::empty(); + for with_param in &call_template.with_params { + if with_param.tunnel { + return Err(error::Error::Unsupported( + "Tunnel params are not supported".to_string(), + ) + .into()); + } + let (value_atom, value_bindings) = + self.with_param_value_atom(with_param)?.atom_bindings(); + bindings = bindings.concat(value_bindings); + params.push(ir::WithParam { + name: with_param.name.clone(), + value: value_atom, + }); + } + + Ok(bindings.bind_expr_no_span( + &mut self.variables, + ir::Expr::CallTemplate(ir::CallTemplate { + name: call_template.name.clone(), + params, }), )) } + fn try_(&mut self, try_: &ast::Try) -> error::SpannedResult { + let try_body = self.select_or_sequence_constructor_function( + try_.select.as_ref(), + &try_.sequence_constructor, + )?; + + let mut catches = Vec::new(); + catches.push(self.catch_clause(&try_.catch)?); + + for entry in &try_.catches { + match entry { + ast::TryCatchOrFallback::Catch(catch) => { + catches.push(self.catch_clause(catch)?); + } + ast::TryCatchOrFallback::Fallback(_) => { + return Err(error::Error::Unsupported( + "xsl:fallback in xsl:try is not supported".to_string(), + ) + .into()); + } + } + } + + let expr = ir::Expr::TryCatch(ir::TryCatch { + try_body, + catches, + rollback_output: try_.rollback_output.unwrap_or(true), + }); + + Ok(Bindings::new(self.variables.new_binding_no_span(expr))) + } + + fn catch_clause(&mut self, catch: &ast::Catch) -> error::SpannedResult { + let errors = 
self.parse_catch_errors(catch.errors.as_ref())?; + let body = self.select_or_sequence_constructor_function( + catch.select.as_ref(), + &catch.sequence_constructor, + )?; + Ok(ir::CatchClause { errors, body }) + } + + fn select_or_sequence_constructor_function( + &mut self, + select: Option<&ast::Expression>, + sequence_constructor: &ast::SequenceConstructor, + ) -> error::SpannedResult { + if let Some(select) = select { + self.expression_function(select) + } else { + self.sequence_constructor_function(sequence_constructor, &[]) + } + } + + fn expression_function( + &mut self, + expression: &ast::Expression, + ) -> error::SpannedResult { + let context_names = self.variables.push_context(); + let bindings = self.expression(expression)?; + self.variables.pop_context(); + let mut params = vec![ + ir::Param { + name: context_names.item, + type_: None, + }, + ir::Param { + name: context_names.position, + type_: None, + }, + ir::Param { + name: context_names.last, + type_: None, + }, + ]; + params.extend(self.global_param_ir_params()); + Ok(ir::FunctionDefinition { + params, + return_type: None, + body: Box::new(bindings.expr()), + }) + } + + fn parse_catch_errors( + &self, + errors: Option<&Vec>, + ) -> error::SpannedResult> { + let errors = match errors { + Some(errors) if !errors.is_empty() => errors, + _ => return Ok(vec![CatchError::Any]), + }; + let mut result = Vec::with_capacity(errors.len()); + for token in errors { + result.push(self.parse_catch_error_token(token)?); + } + Ok(result) + } + + fn parse_catch_error_token(&self, token: &str) -> error::SpannedResult { + let token = token.trim(); + if token == "*" || token == "*:*" { + return Ok(CatchError::Any); + } + if let Some(local) = token.strip_prefix("*:") { + return Ok(CatchError::Local(local.to_string())); + } + if let Some(prefix) = token.strip_suffix(":*") { + let namespace = self + .static_context + .namespaces() + .by_prefix(prefix) + .ok_or_else(|| { + error::Error::Unsupported(format!( + "Unknown 
namespace prefix in xsl:catch errors: {token}" + )) + })?; + return Ok(CatchError::Namespace(namespace.to_string())); + } + if let Some(qname) = token.strip_prefix("Q{") { + if let Some(end) = qname.find('}') { + let namespace = &qname[..end]; + let local = &qname[end + 1..]; + if local == "*" { + return Ok(CatchError::Namespace(namespace.to_string())); + } + if local.is_empty() { + return Err(error::Error::Unsupported(format!( + "Invalid xsl:catch errors token: {token}" + )) + .into()); + } + return Ok(CatchError::Name(Name::new( + local.to_string(), + namespace.to_string(), + String::new(), + ))); + } + } + if !token.contains(':') { + return Ok(CatchError::Name(Name::new( + token.to_string(), + String::new(), + String::new(), + ))); + } + + match parse_name(token, self.static_context.namespaces()) { + Ok(spanned) => Ok(CatchError::Name(spanned.value)), + Err(_) => Err(error::Error::Unsupported(format!( + "Invalid xsl:catch errors token: {token}" + )) + .into()), + } + } + fn select_or_sequence_constructor( &mut self, instruction: &impl ast::SelectOrSequenceConstructor, @@ -562,6 +1199,88 @@ impl<'a> IrConverter<'a> { } } + fn sequence_constructor_document( + &mut self, + sequence_constructor: &ast::SequenceConstructor, + ) -> error::SpannedResult { + let doc_expr = ir::Expr::XmlDocument(ir::XmlRoot {}); + let (doc_atom, mut bindings) = Bindings::new(self.variables.new_binding_no_span(doc_expr)) + .atom_bindings(); + if !sequence_constructor.is_empty() { + let (child_atom, child_bindings) = + self.sequence_constructor(sequence_constructor)?.atom_bindings(); + bindings = bindings.concat(child_bindings); + let append = ir::Expr::XmlAppend(ir::XmlAppend { + parent: doc_atom, + child: child_atom, + }); + bindings = bindings.bind_expr_no_span(&mut self.variables, append); + } + Ok(bindings) + } + + fn sequence_constructor_has_content(sequence_constructor: &ast::SequenceConstructor) -> bool { + sequence_constructor.iter().any(|item| { + matches!(item, 
ast::SequenceConstructorItem::Content(_)) + }) + } + + fn with_param_value_atom( + &mut self, + with_param: &ast::WithParam, + ) -> error::SpannedResult { + if let Some(as_) = &with_param.as_ { + if let Some(occurrence) = Self::string_sequence_occurrence(as_) { + match occurrence { + xpath_ast::Occurrence::One => { + return self.select_or_sequence_constructor_simple_content(with_param); + } + xpath_ast::Occurrence::Option => { + if with_param.select.is_some() + || !with_param.sequence_constructor.is_empty() + { + return self.select_or_sequence_constructor_simple_content(with_param); + } + } + _ => {} + } + } + } + if with_param.select.is_none() + && with_param.sequence_constructor.is_empty() + && with_param.as_.is_none() + { + return Ok(Bindings::new(self.variables.new_binding_no_span(ir::Expr::Atom( + Spanned::new( + ir::Atom::Const(ir::Const::String(String::new())), + (0..0).into(), + ), + )))); + } + if with_param.select.is_none() + && !with_param.sequence_constructor.is_empty() + && Self::sequence_constructor_has_content(&with_param.sequence_constructor) + { + return self.sequence_constructor_document(&with_param.sequence_constructor); + } + self.select_or_sequence_constructor(with_param) + } + + fn string_sequence_occurrence( + sequence_type: &xpath_ast::SequenceType, + ) -> Option { + match sequence_type { + xpath_ast::SequenceType::Item(item) => { + matches!( + item.item_type, + xpath_ast::ItemType::AtomicOrUnionType(Xs::String) + ) + .then_some(item.occurrence) + } + _ => None, + } + } + fn select_or_sequence_constructor_simple_content( &mut self, instruction: &impl ast::SelectOrSequenceConstructor, @@ -612,6 +1331,118 @@ impl<'a> IrConverter<'a> { )) } + fn assert_error_expr(&mut self, assert_: &ast::Assert) -> error::SpannedResult { + let (code_atom, code_bindings) = self.assert_error_code(assert_)?.atom_bindings(); + let (message_atom, message_bindings) = self.assert_message(assert_)?.atom_bindings(); + let bindings = 
code_bindings.concat(message_bindings); + let error_atom = self.static_function_atom("error", FN_NAMESPACE, 2); + let expr = ir::Expr::FunctionCall(ir::FunctionCall { + atom: Spanned::new(error_atom, (0..0).into()), + args: vec![code_atom, message_atom], + }); + Ok(bindings + .bind_expr_no_span(&mut self.variables, expr) + .expr()) + } + + fn assert_error_code(&mut self, assert_: &ast::Assert) -> error::SpannedResult { + let (namespace, local) = if let Some(error_code) = &assert_.error_code { + let literal = self + .value_template_literal(error_code) + .ok_or_else(|| { + error::Error::Unsupported( + "xsl:assert error-code must be a literal in this implementation" + .to_string(), + ) + })?; + self.parse_error_code_literal(&literal)? + } else { + ( + "http://www.w3.org/2005/xqt-errors".to_string(), + "XTMM9001".to_string(), + ) + }; + + Ok(self.qname_expr(&namespace, &local)) + } + + fn assert_message(&mut self, assert_: &ast::Assert) -> error::SpannedResult { + let (select_atom, bindings) = if let Some(select) = &assert_.select { + self.expression(select)?.atom_bindings() + } else { + self.sequence_constructor(&assert_.sequence_constructor)? 
+ .atom_bindings() + }; + + let expr = self.simple_content_expr(select_atom, self.space_separator_atom()); + Ok(bindings.bind_expr_no_span(&mut self.variables, expr)) + } + + fn qname_expr(&mut self, namespace: &str, qname: &str) -> Bindings { + let namespace_atom = Spanned::new( + ir::Atom::Const(ir::Const::String(namespace.to_string())), + (0..0).into(), + ); + let qname_atom = Spanned::new( + ir::Atom::Const(ir::Const::String(qname.to_string())), + (0..0).into(), + ); + let qname_fn = self.static_function_atom("QName", FN_NAMESPACE, 2); + let expr = ir::Expr::FunctionCall(ir::FunctionCall { + atom: Spanned::new(qname_fn, (0..0).into()), + args: vec![namespace_atom, qname_atom], + }); + Bindings::new(self.variables.new_binding_no_span(expr)) + } + + fn value_template_literal(&self, template: &ast::ValueTemplate) -> Option + where + T: Clone + PartialEq + Eq, + { + let mut out = String::new(); + for item in &template.template { + match item { + ast::ValueTemplateItem::String { text, .. } => out.push_str(text), + ast::ValueTemplateItem::Curly { c } => out.push(*c), + ast::ValueTemplateItem::Value { .. 
} => return None, + } + } + Some(out) + } + + fn parse_error_code_literal(&self, value: &str) -> error::SpannedResult<(String, String)> { + if let Some(rest) = value.strip_prefix("Q{") { + if let Some(end) = rest.find('}') { + let namespace = rest[..end].to_string(); + let local = rest[end + 1..].to_string(); + if local.is_empty() { + return Err(error::Error::Unsupported(format!( + "Invalid error-code EQName: {}", + value + )) + .into()); + } + return Ok((namespace, local)); + } + } + + let local = value + .rsplit_once(':') + .map(|(_, local)| local) + .unwrap_or(value) + .to_string(); + + if local.is_empty() { + return Err(error::Error::Unsupported(format!( + "Invalid error-code EQName: {}", + value + )) + .into()); + } + + Ok((String::new(), local)) + } + fn attribute_value_template( &mut self, value_template: &ast::ValueTemplate, @@ -725,6 +1556,26 @@ impl<'a> IrConverter<'a> { self.choose_when_otherwise(&choose.when, choose.otherwise.as_ref()) } + fn assert_(&mut self, assert_: &ast::Assert) -> error::SpannedResult { + if !self.static_context.assertions_enabled() { + let empty = self.empty_sequence().value; + return Ok(Bindings::empty().bind_expr_no_span( + &mut self.variables, + empty, + )); + } + let (condition, bindings) = self.expression(&assert_.test)?.atom_bindings(); + let error_expr = self.assert_error_expr(assert_)?; + + let expr = ir::Expr::If(ir::If { + condition, + then: Box::new(self.empty_sequence()), + else_: Box::new(error_expr), + }); + + Ok(bindings.bind_expr_no_span(&mut self.variables, expr)) + } + fn choose_when_otherwise( &mut self, when: &[ast::When], @@ -864,9 +1715,21 @@ impl<'a> IrConverter<'a> { let expr = ir::Expr::CopyShallow(ir::CopyShallow { select: context_atom, }); - let (copy_atom, bindings) = bindings + let (mut copy_atom, mut bindings) = bindings .bind_expr_no_span(&mut self.variables, expr) .atom_bindings(); + if let Some(type_name) = &copy.type_ { + let xs = self.xs_type_from_eqname(type_name, copy.span)?; + let set_type_expr =
ir::Expr::XmlSetType(ir::XmlSetType { + node: copy_atom.clone(), + xs, + }); + let (typed_atom, set_type_bindings) = bindings + .bind_expr_no_span(&mut self.variables, set_type_expr) + .atom_bindings(); + bindings = set_type_bindings; + copy_atom = typed_atom; + } // if it is an element or document, // execute sequence constructor @@ -922,7 +1785,51 @@ impl<'a> IrConverter<'a> { fn copy_of(&mut self, copy_of: &ast::CopyOf) -> error::SpannedResult { let (atom, bindings) = self.expression(&copy_of.select)?.atom_bindings(); let copy_deep_expr = ir::Expr::CopyDeep(ir::CopyDeep { select: atom }); - Ok(bindings.bind_expr_no_span(&mut self.variables, copy_deep_expr)) + let (copy_atom, mut bindings) = bindings + .bind_expr_no_span(&mut self.variables, copy_deep_expr) + .atom_bindings(); + if let Some(type_name) = &copy_of.type_ { + let xs = self.xs_type_from_eqname(type_name, copy_of.span)?; + let set_type_expr = ir::Expr::XmlSetType(ir::XmlSetType { + node: copy_atom.clone(), + xs, + }); + let (_typed_atom, set_type_bindings) = bindings + .bind_expr_no_span(&mut self.variables, set_type_expr) + .atom_bindings(); + bindings = set_type_bindings; + } + Ok(bindings) + } + + fn document(&mut self, document: &ast::Document) -> error::SpannedResult { + let doc_expr = ir::Expr::XmlDocument(ir::XmlRoot {}); + let (mut doc_atom, mut bindings) = + Bindings::new(self.variables.new_binding_no_span(doc_expr)).atom_bindings(); + if let Some(type_name) = &document.type_ { + let xs = self.xs_type_from_eqname(type_name, document.span)?; + let set_type_expr = ir::Expr::XmlSetType(ir::XmlSetType { + node: doc_atom.clone(), + xs, + }); + let (typed_atom, set_type_bindings) = bindings + .bind_expr_no_span(&mut self.variables, set_type_expr) + .atom_bindings(); + bindings = set_type_bindings; + doc_atom = typed_atom; + } + if !document.sequence_constructor.is_empty() { + let (child_atom, child_bindings) = self + .sequence_constructor(&document.sequence_constructor)? 
+ .atom_bindings(); + bindings = bindings.concat(child_bindings); + let append = ir::Expr::XmlAppend(ir::XmlAppend { + parent: doc_atom, + child: child_atom, + }); + bindings = bindings.bind_expr_no_span(&mut self.variables, append); + } + Ok(bindings) } fn sequence(&mut self, sequence: &ast::Sequence) -> error::SpannedResult { @@ -978,9 +1885,21 @@ impl<'a> IrConverter<'a> { .atom_bindings(); let expr = ir::Expr::XmlElement(ir::XmlElement { name: name_atom }); - let (element_atom, bindings) = bindings + let (mut element_atom, mut bindings) = bindings .bind_expr_no_span(&mut self.variables, expr) .atom_bindings(); + if let Some(type_name) = &element.type_ { + let xs = self.xs_type_from_eqname(type_name, element.span)?; + let set_type_expr = ir::Expr::XmlSetType(ir::XmlSetType { + node: element_atom.clone(), + xs, + }); + let (typed_atom, set_type_bindings) = bindings + .bind_expr_no_span(&mut self.variables, set_type_expr) + .atom_bindings(); + bindings = set_type_bindings; + element_atom = typed_atom; + } let sequence_constructor_bindings = self.sequence_constructor_append(element_atom, &element.sequence_constructor)?; Ok(bindings.concat(sequence_constructor_bindings)) @@ -1007,13 +1926,25 @@ impl<'a> IrConverter<'a> { )? 
.atom_bindings(); let bindings = name_bindings.concat(text_bindings); - Ok(bindings.bind_expr_no_span( - &mut self.variables, - ir::Expr::XmlAttribute(ir::XmlAttribute { - name: name_atom, - value: text_atom, - }), - )) + let attribute_expr = ir::Expr::XmlAttribute(ir::XmlAttribute { + name: name_atom, + value: text_atom, + }); + let (attribute_atom, mut bindings) = bindings + .bind_expr_no_span(&mut self.variables, attribute_expr) + .atom_bindings(); + if let Some(type_name) = &attribute.type_ { + let xs = self.xs_type_from_eqname(type_name, attribute.span)?; + let set_type_expr = ir::Expr::XmlSetType(ir::XmlSetType { + node: attribute_atom.clone(), + xs, + }); + let (_typed_atom, set_type_bindings) = bindings + .bind_expr_no_span(&mut self.variables, set_type_expr) + .atom_bindings(); + bindings = set_type_bindings; + } + Ok(bindings) } fn namespace(&mut self, namespace: &ast::Namespace) -> error::SpannedResult { @@ -1073,8 +2004,15 @@ impl<'a> IrConverter<'a> { } fn xpath(&mut self, xpath: &xee_xpath_ast::ast::ExprS) -> error::SpannedResult { - let mut ir_converter = - xee_xpath_compiler::IrConverter::new(&mut self.variables, self.static_context); + let mut ir_converter = if let Some(user_functions) = &self.user_functions { + xee_xpath_compiler::IrConverter::new_with_user_functions( + &mut self.variables, + self.static_context, + user_functions.clone(), + ) + } else { + xee_xpath_compiler::IrConverter::new(&mut self.variables, self.static_context) + }; ir_converter.expr(xpath) } @@ -1096,7 +2034,7 @@ impl<'a> IrConverter<'a> { }); let bindings = bindings.bind_expr(&mut self.variables, Spanned::new(filter, (0..0).into())); - let params = vec![ + let mut params = vec![ ir::Param { name: context_names.item, type_: None, @@ -1110,6 +2048,7 @@ impl<'a> IrConverter<'a> { type_: None, }, ]; + params.extend(self.global_param_ir_params()); Ok(ir::FunctionDefinition { params, @@ -1117,4 +2056,268 @@ impl<'a> IrConverter<'a> { body: Box::new(bindings.expr()), }) } + + fn 
template_params( + &mut self, + params: &[ast::Param], + ) -> error::SpannedResult<(Vec, Vec)> { + let mut template_params = Vec::new(); + let mut template_ir_params = Vec::new(); + let mut seen = HashSet::new(); + + for param in params { + if param.static_ { + return Err(error::Error::Unsupported( + "Static template params are not supported".to_string(), + ) + .into()); + } + if param.tunnel { + return Err(error::Error::Unsupported( + "Tunnel params are not supported".to_string(), + ) + .into()); + } + if !seen.insert(param.name.clone()) { + return Err(error::Error::Unsupported( + "Duplicate template param names are not supported".to_string(), + ) + .into()); + } + if param.required + && (param.select.is_some() || !param.sequence_constructor.is_empty()) + { + return Err(error::SpannedError { + error: error::Error::XTSE0010, + span: Some((param.span.start..param.span.end).into()), + }); + } + + let default_expr = if let Some(select) = &param.select { + Some(self.expression(select)?.expr()) + } else if !param.sequence_constructor.is_empty() { + if Self::sequence_constructor_has_content(&param.sequence_constructor) { + Some( + self.sequence_constructor_document(&param.sequence_constructor)? 
+ .expr(), + ) + } else { + Some(self.sequence_constructor(&param.sequence_constructor)?.expr()) + } + } else { + None + }; + let var_name = self.variables.new_var_name(&param.name); + template_params.push(ir::TemplateParam { + name: param.name.clone(), + var_name: var_name.clone(), + required: param.required, + default_expr, + type_: param.as_.clone(), + }); + template_ir_params.push(ir::Param { + name: var_name, + type_: param.as_.clone(), + }); + } + + Ok((template_params, template_ir_params)) + } + + fn function_params(&mut self, params: &[ast::Param]) -> error::SpannedResult> { + let mut function_params = Vec::new(); + let mut seen = HashSet::new(); + + for param in params { + if param.static_ { + return Err(error::Error::Unsupported( + "Static function params are not supported".to_string(), + ) + .into()); + } + if param.tunnel { + return Err(error::Error::Unsupported( + "Tunnel function params are not supported".to_string(), + ) + .into()); + } + if param.select.is_some() || !param.sequence_constructor.is_empty() { + return Err(error::Error::Unsupported( + "Function param defaults are not supported".to_string(), + ) + .into()); + } + if !seen.insert(param.name.clone()) { + return Err(error::Error::Unsupported( + "Duplicate function param names are not supported".to_string(), + ) + .into()); + } + + let var_name = self.variables.new_var_name(&param.name); + function_params.push(ir::Param { + name: var_name, + type_: param.as_.clone(), + }); + } + + Ok(function_params) + } + + fn collect_global_params( + &mut self, + declarations: &[DeclarationWithImport], + ) -> error::SpannedResult<()> { + for declaration in declarations { + match &declaration.declaration { + ast::Declaration::Param(param) => { + if param.static_ { + continue; + } + if self.global_param_lookup.contains_key(&param.name) { + continue; + } + let var_name = self.variables.new_var_name(&param.name); + let index = self.global_params.len(); + self.global_param_lookup.insert(param.name.clone(), index); + 
self.global_params.push(ir::GlobalParam { + name: param.name.clone(), + var_name, + required: param.required, + overrideable: true, + default_expr: None, + }); + } + ast::Declaration::Variable(variable) => { + if variable.static_ { + continue; + } + if self.global_param_lookup.contains_key(&variable.name) { + continue; + } + let var_name = self.variables.new_var_name(&variable.name); + let index = self.global_params.len(); + self.global_param_lookup.insert(variable.name.clone(), index); + self.global_params.push(ir::GlobalParam { + name: variable.name.clone(), + var_name, + required: false, + overrideable: false, + default_expr: None, + }); + } + _ => {} + } + } + Ok(()) + } + + fn collect_functions( + &mut self, + declarations: &[DeclarationWithImport], + ) -> error::SpannedResult<()> { + for declaration in declarations { + if let ast::Declaration::Function(function) = &declaration.declaration { + let arity = function.params.len(); + if arity > u8::MAX as usize { + return Err(error::Error::Unsupported( + "Function arity too large".to_string(), + ) + .into()); + } + let key = (function.name.clone(), arity as u8); + if self.function_lookup.contains_key(&key) { + return Err(error::Error::Unsupported( + "Duplicate function declaration".to_string(), + ) + .into()); + } + let index = self.function_lookup.len(); + self.function_lookup.insert(key, index); + } + } + if !self.function_lookup.is_empty() { + self.user_functions = Some(UserFunctions::new( + self.function_lookup.clone(), + self.global_param_names(), + )); + } + Ok(()) + } + + fn param_declaration(&mut self, param: &ast::Param) -> error::SpannedResult<()> { + if param.static_ { + return Ok(()); + } + if param.required && (param.select.is_some() || !param.sequence_constructor.is_empty()) { + return Err(error::SpannedError { + error: error::Error::XTSE0010, + span: Some((param.span.start..param.span.end).into()), + }); + } + let default_expr = if let Some(select) = &param.select { +
Some(self.expression(select)?.expr()) + } else if !param.sequence_constructor.is_empty() { + if Self::sequence_constructor_has_content(&param.sequence_constructor) { + Some( + self.sequence_constructor_document(&param.sequence_constructor)? + .expr(), + ) + } else { + Some(self.sequence_constructor(&param.sequence_constructor)?.expr()) + } + } else { + None + }; + if let Some(index) = self.global_param_lookup.get(&param.name).copied() { + if let Some(entry) = self.global_params.get_mut(index) { + entry.required = param.required; + entry.default_expr = default_expr; + } + } + Ok(()) + } + + fn variable_declaration(&mut self, variable: &ast::Variable) -> error::SpannedResult<()> { + if variable.static_ { + return Ok(()); + } + let default_expr = if let Some(select) = &variable.select { + Some(self.expression(select)?.expr()) + } else if !variable.sequence_constructor.is_empty() { + if Self::sequence_constructor_has_content(&variable.sequence_constructor) { + Some( + self.sequence_constructor_document(&variable.sequence_constructor)?
+ .expr(), + ) + } else { + Some(self.sequence_constructor(&variable.sequence_constructor)?.expr()) + } + } else { + None + }; + if let Some(index) = self.global_param_lookup.get(&variable.name).copied() { + if let Some(entry) = self.global_params.get_mut(index) { + entry.default_expr = default_expr; + } + } + Ok(()) + } + + fn global_param_ir_params(&self) -> Vec { + self.global_params + .iter() + .map(|param| ir::Param { + name: param.var_name.clone(), + type_: None, + }) + .collect() + } + + fn global_param_names(&self) -> Vec { + self.global_params + .iter() + .map(|param| param.var_name.clone()) + .collect() + } } diff --git a/xee-xslt-compiler/src/lib.rs b/xee-xslt-compiler/src/lib.rs index 7cdabcb5f..87f48488a 100644 --- a/xee-xslt-compiler/src/lib.rs +++ b/xee-xslt-compiler/src/lib.rs @@ -3,5 +3,5 @@ mod default_declarations; mod priority; mod run; -pub use ast_ir::parse; +pub use ast_ir::{parse, parse_with_base}; pub use run::evaluate; diff --git a/xee-xslt-compiler/src/priority.rs b/xee-xslt-compiler/src/priority.rs index 96d92e185..47833f72e 100644 --- a/xee-xslt-compiler/src/priority.rs +++ b/xee-xslt-compiler/src/priority.rs @@ -10,6 +10,9 @@ type Pattern = pattern::Pattern; pub(crate) fn default_priority<'a>( pattern: &'a Pattern, ) -> Box, Decimal)> + 'a> { + if let Some(binary_expr) = top_level_binary_expr(pattern) { + return default_priority_top_level_binary(Cow::Borrowed(pattern), binary_expr); + } match pattern { pattern::Pattern::Predicate(predicate) => { if !predicate.predicates.is_empty() { @@ -28,6 +31,32 @@ pub(crate) fn default_priority<'a>( } } +fn top_level_binary_expr( + pattern: &Pattern, +) -> Option<&pattern::BinaryExpr> { + match pattern { + pattern::Pattern::Expr(pattern::ExprPattern::BinaryExpr(binary_expr)) => Some(binary_expr), + pattern::Pattern::Expr(pattern::ExprPattern::Path(path)) => { + path_top_level_binary_expr(path) + } + _ => None, + } +} + +fn path_top_level_binary_expr( + path: &pattern::PathExpr, +) -> 
Option<&pattern::BinaryExpr> { + match path.steps.as_slice() { + [pattern::StepExpr::PostfixExpr(postfix)] if postfix.predicates.is_empty() => { + match &postfix.expr { + pattern::ExprPattern::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + } + } + _ => None, + } +} + fn default_priority_top_level_binary<'a>( pattern: Cow<'a, Pattern>, binary_expr: &'a pattern::BinaryExpr, @@ -189,23 +218,37 @@ mod tests { v[0].1 } + fn flatten_union(pattern: &Pattern) -> Vec<Pattern> { + let mut out = Vec::new(); + flatten_union_inner(pattern, &mut out); + out + } + + fn flatten_union_inner(pattern: &Pattern, out: &mut Vec<Pattern>) { + if let Some(binary_expr) = top_level_binary_expr(pattern) { + if binary_expr.operator == pattern::Operator::Union { + let left = Pattern::Expr(binary_expr.left.as_ref().clone()); + flatten_union_inner(&left, out); + let right = Pattern::Expr(binary_expr.right.as_ref().clone()); + flatten_union_inner(&right, out); + return; + } + } + out.push(pattern.clone()); + } + #[test] fn test_2_top_level_union_is_multiple_patterns() { let pattern = parse("foo | bar"); - let (first_pattern, second_pattern) = match pattern.clone() { - Pattern::Expr(pattern::ExprPattern::BinaryExpr(binary_expr)) => ( - pattern::Pattern::Expr(binary_expr.left.as_ref().clone()), - pattern::Pattern::Expr(binary_expr.right.as_ref().clone()), - ), - _ => panic!("Expected binary expression"), - }; + let parts = flatten_union(&pattern); + assert_eq!(parts.len(), 2); let priorities = default_priority(&pattern).collect::<Vec<_>>(); assert_eq!( priorities, vec![ - (Cow::Owned(first_pattern), dec!(0)), - (Cow::Owned(second_pattern), dec!(0)) + (Cow::Owned(parts[0].clone()), dec!(0)), + (Cow::Owned(parts[1].clone()), dec!(0)) ] ); } @@ -213,20 +256,15 @@ mod tests { #[test] fn test_2_top_level_union_is_multiple_patterns_different_priority() { let pattern = parse("(/) | bar"); - let (first_pattern, second_pattern) = match pattern.clone() { - Pattern::Expr(pattern::ExprPattern::BinaryExpr(binary_expr)) =>
( - pattern::Pattern::Expr(binary_expr.left.as_ref().clone()), - pattern::Pattern::Expr(binary_expr.right.as_ref().clone()), - ), - _ => panic!("Expected binary expression"), - }; + let parts = flatten_union(&pattern); + assert_eq!(parts.len(), 2); let priorities = default_priority(&pattern).collect::<Vec<_>>(); assert_eq!( priorities, vec![ - (Cow::Owned(first_pattern), dec!(-0.5)), - (Cow::Owned(second_pattern), dec!(0)) + (Cow::Owned(parts[0].clone()), dec!(-0.5)), + (Cow::Owned(parts[1].clone()), dec!(0)) ] ); } @@ -234,27 +272,16 @@ mod tests { #[test] fn test_2_top_level_union_more_unions() { let pattern = parse("foo | bar | baz"); - let ((first_pattern, second_pattern), third_pattern) = match pattern.clone() { - Pattern::Expr(pattern::ExprPattern::BinaryExpr(binary_expr)) => ( - match binary_expr.left.as_ref() { - pattern::ExprPattern::BinaryExpr(binary_expr) => ( - pattern::Pattern::Expr(binary_expr.left.as_ref().clone()), - pattern::Pattern::Expr(binary_expr.right.as_ref().clone()), - ), - _ => panic!("Expected binary expression"), - }, - pattern::Pattern::Expr(binary_expr.right.as_ref().clone()), - ), - _ => panic!("Expected binary expression"), - }; + let parts = flatten_union(&pattern); + assert_eq!(parts.len(), 3); let priorities = default_priority(&pattern).collect::<Vec<_>>(); assert_eq!( priorities, vec![ - (Cow::Owned(first_pattern), dec!(0)), - (Cow::Owned(second_pattern), dec!(0)), - (Cow::Owned(third_pattern), dec!(0)) + (Cow::Owned(parts[0].clone()), dec!(0)), + (Cow::Owned(parts[1].clone()), dec!(0)), + (Cow::Owned(parts[2].clone()), dec!(0)) ] ); } diff --git a/xee-xslt-compiler/tests/test_xslt.rs b/xee-xslt-compiler/tests/test_xslt.rs index 946324eba..7a19b5f9b 100644 --- a/xee-xslt-compiler/tests/test_xslt.rs +++ b/xee-xslt-compiler/tests/test_xslt.rs @@ -1,7 +1,14 @@ use std::fmt::Write; -use xee_interpreter::{error, sequence::Sequence}; -use xee_xslt_compiler::evaluate; +use xee_interpreter::{ + context::{StaticContext, TypeTableRef}, + error, +
sequence::Sequence, + xml::Documents, +}; +use xee_name::{Namespaces, FN_NAMESPACE}; +use xee_schema_type::Xs; +use xee_xslt_compiler::{evaluate, parse}; use xot::Xot; fn xml(xot: &Xot, sequence: Sequence) -> String { @@ -14,6 +21,44 @@ fn xml(xot: &Xot, sequence: Sequence) -> String { f } +fn evaluate_with_type_table( + xot: &mut Xot, + xml: &str, + xslt: &str, + type_table: TypeTableRef, +) -> error::SpannedResult { + let namespaces = Namespaces::new( + Namespaces::default_namespaces(), + "".to_string(), + FN_NAMESPACE.to_string(), + ); + let static_context = StaticContext::from_namespaces(namespaces); + let root = xot.parse(xml).unwrap(); + let program = parse(static_context, xslt).unwrap(); + let mut documents = Documents::new(); + let handle = documents.add_root(None, root).unwrap(); + let root = documents.get_node_by_handle(handle).unwrap(); + let mut dynamic_context_builder = program.dynamic_context_builder(); + dynamic_context_builder.context_node(root); + dynamic_context_builder.documents(documents); + dynamic_context_builder.type_table(type_table); + let context = dynamic_context_builder.build(); + let runnable = program.runnable(&context); + runnable.many(xot) +} + +fn child_element_named(xot: &Xot, parent: xot::Node, name: &str) -> xot::Node { + let target = xot.name(name).unwrap(); + let mut child = xot.first_child(parent); + while let Some(node) = child { + if xot.is_element(node) && xot.node_name(node) == Some(target) { + return node; + } + child = xot.next_sibling(node); + } + panic!("child element not found: {}", name); +} + #[test] fn test_transform() { let mut xot = Xot::new(); @@ -1230,3 +1275,88 @@ fn test_basic_iterate_params() { "124" ); } + +#[test] +fn test_try_catch_rollback_output() { + let mut xot = Xot::new(); + let output = evaluate( + &mut xot, + "", + r#" + + + + + + + + + + +"#, + ) + .unwrap(); + assert_eq!(xml(&xot, output), ""); +} + +#[test] +fn test_type_table_for_typed_constructors() { + let mut xot = Xot::new(); + let 
type_table = TypeTableRef::new(); + let output = evaluate_with_type_table( + &mut xot, + "", + r#" + + + + + + + + + + + + + + +"#, + type_table.clone(), + ) + .unwrap(); + + let root_name = xot.name("root").unwrap(); + let mut root_node = None; + let mut document_node = None; + for item in output.iter() { + let Ok(node) = item.to_node() else { + continue; + }; + if xot.is_element(node) && xot.node_name(node) == Some(root_name) { + root_node = Some(node); + } else if xot.is_document(node) { + document_node = Some(node); + } + } + + let root_node = root_node.expect("root element missing from output"); + let document_node = document_node.expect("document node missing from output"); + let typed_element = child_element_named(&xot, root_node, "typed-element"); + let attr_holder = child_element_named(&xot, root_node, "attr-holder"); + let copy_node = child_element_named(&xot, root_node, "src-copy"); + let copy_of_node = child_element_named(&xot, root_node, "src-copy-of"); + let typed_attr = xot + .attributes(attr_holder) + .get_node(xot.name("typed-attr").unwrap()) + .unwrap(); + + let type_table = type_table.borrow(); + assert_eq!(type_table.get(typed_element), Some(Xs::String)); + assert_eq!(type_table.get(typed_attr), Some(Xs::Integer)); + assert_eq!(type_table.get(copy_node), Some(Xs::Decimal)); + assert_eq!(type_table.get(copy_of_node), Some(Xs::Boolean)); + assert_eq!(type_table.get(document_node), Some(Xs::AnyType)); +} diff --git a/xslt-plan.md b/xslt-plan.md index c1e6bf747..3fae78bc6 100644 --- a/xslt-plan.md +++ b/xslt-plan.md @@ -5,7 +5,7 @@ are, and how people could contribute. ## Current status -`xee-xslt-ast` parses the XSTL stylesheets into an AST. This AST is +`xee-xslt-ast` parses the XSLT stylesheets into an AST. This AST is very similar in structure to the underlying XML. `xee-xslt-compiler` compiles this AST to IR (as defined by `xee-ir`). @@ -49,16 +49,16 @@ This won't make any more XSLT run but it ensures what we load is correct. 
### xee-testrunner -We want `xee-testrunner` to be able to execute the XSLT test suite in -`vendor/xslt`. Then we can slowly build up test coverage. Martijn has done -preparatory work and is working towards the ability to run our first XSLT -conformance tests. +`xee-testrunner` can execute the XSLT test suite in `vendor/xslt-tests`, but +coverage is still partial and many tests are filtered. The focus now is +expanding coverage and improving runner support alongside new XSLT +functionality. ### xee-xslt-compiler We want to extend `xee-xslt-compiler` so it can compile more XSLT constructs to the IR. The compilation code is in `src/ast_ir.rs`. We can extend the tests in -`test/test_xslt.rs` but the test runner once it works can also help drive this. +`tests/test_xslt.rs` but the test runner can also help drive this. Right now we don't have snapshot tests to verify that particular AST gets transformed into particular IR, but it may be useful to add this.