diff --git a/src/main/scala/org/geneontology/rules/engine/ReteNodes.scala b/src/main/scala/org/geneontology/rules/engine/ReteNodes.scala index 1cb2092..6344c0e 100644 --- a/src/main/scala/org/geneontology/rules/engine/ReteNodes.scala +++ b/src/main/scala/org/geneontology/rules/engine/ReteNodes.scala @@ -2,9 +2,6 @@ package org.geneontology.rules.engine import scala.collection.mutable -import scalaz._ -import scalaz.Scalaz._ - final class AlphaNode(val pattern: TriplePattern) { var children: List[JoinNode] = Nil @@ -18,10 +15,9 @@ final class AlphaNode(val pattern: TriplePattern) { def activate(triple: Triple, memory: WorkingMemory): Unit = { val alphaMem = memory.alpha.getOrElseUpdate(pattern, new AlphaMemory(pattern)) alphaMem.triples = triple :: alphaMem.triples - alphaMem.tripleIndexS = alphaMem.tripleIndexS |+| Map(triple.s -> Set(triple)) - alphaMem.tripleIndexP = alphaMem.tripleIndexP |+| Map(triple.p -> Set(triple)) - alphaMem.tripleIndexO = alphaMem.tripleIndexO |+| Map(triple.o -> Set(triple)) - //children.foreach(_.rightActivate(triple, memory)) + alphaMem.tripleIndexS += triple.s -> (triple :: alphaMem.tripleIndexS.getOrElse(triple.s, Nil)) + alphaMem.tripleIndexP += triple.p -> (triple :: alphaMem.tripleIndexP.getOrElse(triple.p, Nil)) + alphaMem.tripleIndexO += triple.o -> (triple :: alphaMem.tripleIndexO.getOrElse(triple.o, Nil)) alphaMem.linkedChildren.foreach(_.rightActivate(triple, memory)) } @@ -29,7 +25,7 @@ final class AlphaNode(val pattern: TriplePattern) { sealed trait BetaNode { - def spec: List[TriplePattern] + def spec: JoinNodeSpec def addChild(node: BetaNode): Unit @@ -47,7 +43,7 @@ final object BetaRoot extends BetaNode with BetaParent { def leftActivate(token: Token, memory: WorkingMemory): Unit = () def addChild(node: BetaNode): Unit = () - val spec: List[TriplePattern] = Nil + val spec: JoinNodeSpec = JoinNodeSpec(Nil) val memory: BetaMemory = new BetaMemory(spec, Nil) val children = Nil memory.tokens = Token(Map.empty, Nil) :: memory.tokens @@ -71,10 +67,10 @@ final case class Token(bindings: Map[Variable, ConcreteNode], triples: List[Trip } -final class JoinNode(val leftParent: BetaNode with BetaParent, rightParent: AlphaNode, val spec: List[TriplePattern]) extends BetaNode with BetaParent { +final class JoinNode(val leftParent: BetaNode with BetaParent, rightParent: AlphaNode, val spec: JoinNodeSpec) extends BetaNode with BetaParent { - private val thisPattern = spec.head - private val parentBoundVariables = spec.drop(1).flatMap(_.variables).toSet + private val thisPattern = spec.pattern.head + private val parentBoundVariables = spec.pattern.drop(1).flatMap(_.variables).toSet private val thisPatternVariables = thisPattern.variables private val matchVariables = parentBoundVariables intersect thisPatternVariables private val rightParentPattern = rightParent.pattern @@ -94,7 +90,7 @@ final class JoinNode(val leftParent: BetaNode with BetaParent, rightParent: Alph betaMem.checkLeftLink = true } var valid = true - var possibleTriples: List[Set[Triple]] = Nil + var possibleTriples: List[List[Triple]] = Nil if (thisPattern.s.isInstanceOf[Variable]) { val v = thisPattern.s.asInstanceOf[Variable] if (parentBoundVariables(v)) { @@ -141,8 +137,6 @@ final class JoinNode(val leftParent: BetaNode with BetaParent, rightParent: Alph _ = tokensToSend = newToken :: tokensToSend (bindingVar, bindingValue) <- newToken.bindings } { - //betaMem.tokenIndex.getOrElseUpdate(binding, mutable.Set.empty).add(newToken) - //betaMem.tokenIndex.getOrElseUpdate(bindingVar, mutable.AnyRefMap.empty).getOrElseUpdate(bindingValue, mutable.Set.empty).add(newToken) val currentMap = betaMem.tokenIndex.getOrElseUpdate(bindingVar, mutable.AnyRefMap.empty) val currentList = currentMap.getOrElse(bindingValue, Nil) currentMap(bindingValue) = newToken :: currentList @@ -224,7 +218,6 @@ final class ProductionNode(rule: Rule, parent: BetaNode, engine: RuleEngine) ext for { pattern <- rule.head } { - //FIXME get rid of casting val newTriple = Triple( produceNode(pattern.s, token).asInstanceOf[Resource], produceNode(pattern.p, token).asInstanceOf[URI], @@ -237,11 +230,17 @@ final class ProductionNode(rule: Rule, parent: BetaNode, engine: RuleEngine) ext private def produceNode(node: Node, token: Token): ConcreteNode = node match { case c: ConcreteNode => c case v: Variable => token.bindings(v) - //case AnyNode => error + case AnyNode => throw new RuntimeException("Invalid rule head containing AnyNode") } def addChild(node: BetaNode): Unit = () - val spec: List[TriplePattern] = Nil + val spec: JoinNodeSpec = JoinNodeSpec(Nil) } + +final case class JoinNodeSpec(pattern: List[TriplePattern]) { + + override val hashCode: Int = pattern.hashCode + +} \ No newline at end of file diff --git a/src/main/scala/org/geneontology/rules/engine/RuleEngine.scala b/src/main/scala/org/geneontology/rules/engine/RuleEngine.scala index 6c0f794..dfb6597 100644 --- a/src/main/scala/org/geneontology/rules/engine/RuleEngine.scala +++ b/src/main/scala/org/geneontology/rules/engine/RuleEngine.scala @@ -20,7 +20,7 @@ final class RuleEngine(inputRules: Iterable[Rule], val storeDerivations: Boolean val blankPattern = pattern.blankVariables val alphaNode = alphaNodeIndex.getOrElseUpdate(blankPattern, new AlphaNode(blankPattern)) val thisPatternSequence = pattern :: parentPatterns - val joinNode = joinIndex.getOrElseUpdate(thisPatternSequence, new JoinNode(parent, alphaNode, thisPatternSequence)) + val joinNode = joinIndex.getOrElseUpdate(thisPatternSequence, new JoinNode(parent, alphaNode, JoinNodeSpec(thisPatternSequence))) parent.addChild(joinNode) alphaNode.addChild(joinNode) if (parent == BetaRoot) topJoinNodes += joinNode @@ -65,24 +65,21 @@ final class RuleEngine(inputRules: Iterable[Rule], val storeDerivations: Boolean private val DegeneratePattern = TriplePattern(AnyNode, AnyNode, AnyNode) protected[engine] def processTriple(triple: Triple, memory: WorkingMemory): Unit = { - if (!memory.facts(triple)) { - memory.facts += triple + if (memory.facts.add(triple)) { memory.agenda = memory.agenda.enqueue(triple) } - } protected[engine] def processDerivedTriple(triple: Triple, derivation: Derivation, memory: WorkingMemory) = { - if (!memory.facts(triple)) { - memory.facts += triple - //if (memory.facts.size % 100000 == 0) println(memory.facts.size) - memory.derivations = memory.derivations |+| Map(triple -> List(derivation)) + if (memory.facts.add(triple)) { + memory.derivations += triple -> (derivation :: memory.derivations.getOrElse(triple, Nil)) memory.agenda = memory.agenda.enqueue(triple) } } private def injectTriple(triple: Triple, memory: WorkingMemory): Unit = { - val patterns = List(DegeneratePattern, + val patterns = List( + DegeneratePattern, TriplePattern(AnyNode, AnyNode, triple.o), TriplePattern(AnyNode, triple.p, AnyNode), TriplePattern(AnyNode, triple.p, triple.o), diff --git a/src/main/scala/org/geneontology/rules/engine/WorkingMemory.scala b/src/main/scala/org/geneontology/rules/engine/WorkingMemory.scala index 664ac74..3794cc1 100644 --- a/src/main/scala/org/geneontology/rules/engine/WorkingMemory.scala +++ b/src/main/scala/org/geneontology/rules/engine/WorkingMemory.scala @@ -4,14 +4,14 @@ import scala.collection.immutable.Queue import scala.collection.mutable import scala.collection.mutable.AnyRefMap -final class WorkingMemory(var asserted: Set[Triple]) { +final class WorkingMemory(val asserted: Set[Triple]) { var agenda: Queue[Triple] = Queue.empty - var facts: Set[Triple] = asserted + val facts: mutable.Set[Triple] = mutable.Set.empty ++ asserted var derivations: Map[Triple, List[Derivation]] = Map.empty val alpha: mutable.Map[TriplePattern, AlphaMemory] = AnyRefMap.empty - val beta: mutable.Map[List[TriplePattern], BetaMemory] = AnyRefMap.empty + val beta: mutable.Map[JoinNodeSpec, BetaMemory] = AnyRefMap.empty beta += (BetaRoot.spec -> BetaRoot.memory) def explain(triple: Triple): Set[Explanation] = explainAll(Set(triple)) @@ -48,19 +48,18 @@ final class WorkingMemory(var asserted: Set[Triple]) { final class AlphaMemory(pattern: TriplePattern) { var triples: List[Triple] = Nil - var tripleIndexS: Map[ConcreteNode, Set[Triple]] = Map.empty - var tripleIndexP: Map[ConcreteNode, Set[Triple]] = Map.empty - var tripleIndexO: Map[ConcreteNode, Set[Triple]] = Map.empty + var tripleIndexS: Map[ConcreteNode, List[Triple]] = Map.empty + var tripleIndexP: Map[ConcreteNode, List[Triple]] = Map.empty + var tripleIndexO: Map[ConcreteNode, List[Triple]] = Map.empty var linkedChildren: List[JoinNode] = Nil } -final class BetaMemory(val spec: List[TriplePattern], initialLinkedChildren: List[BetaNode]) { +final class BetaMemory(val spec: JoinNodeSpec, initialLinkedChildren: List[BetaNode]) { var tokens: List[Token] = Nil var checkRightLink: Boolean = true var checkLeftLink: Boolean = false - //val tokenIndex: mutable.Map[(Variable, ConcreteNode), mutable.Set[Token]] = AnyRefMap.empty val tokenIndex: mutable.Map[Variable, mutable.Map[ConcreteNode, List[Token]]] = AnyRefMap.empty var linkedChildren: List[BetaNode] = initialLinkedChildren