From 81e666cce7c513dc3a73d9155030de822b41f388 Mon Sep 17 00:00:00 2001 From: Kevin Lano Date: Wed, 14 Aug 2024 14:55:17 +0100 Subject: [PATCH] Improved type inference --- BSystemTypes.java | 12 +-- BinaryExpression.java | 190 +++++++++++++++++++++++++++++++++++++----- Compiler2.java | 14 +++- SetExpression.java | 81 +++++++++++++++++- Statement.java | 13 ++- Type.java | 62 +++++++++++++- 6 files changed, 336 insertions(+), 36 deletions(-) diff --git a/BSystemTypes.java b/BSystemTypes.java index fa753e8c..5abe3555 100644 --- a/BSystemTypes.java +++ b/BSystemTypes.java @@ -176,7 +176,7 @@ else if (ename.equals("boolean")) res = res + " return _results_" + oldindex + ";\n }\n\n"; // Version for maps: - res = res + " public static Map select_" + oldindex + "(Map _l"; + res = res + " public static HashMap select_" + oldindex + "(Map _l"; for (int i = 0; i < pars.size(); i++) { Attribute par = (Attribute) pars.get(i); @@ -186,7 +186,7 @@ else if (ename.equals("boolean")) } res = res + ")\n"; res = res + " { // Implements: " + left + "->select(" + var + " | " + pred + ")\n" + - " Map _results_" + oldindex + " = new java.util.HashMap();\n" + + " HashMap _results_" + oldindex + " = new HashMap();\n" + " java.util.Set _keys = _l.keySet();\n" + " for (Object _i : _keys)\n"; if (ename.equals("int") || "Integer".equals(tname)) @@ -773,7 +773,7 @@ else if (ename.equals("boolean")) res = res + " return _results_" + oldindex + ";\n }\n\n"; // Also need a Map version: - res = res + " public static Map collect_" + oldindex + "(Map _l"; + res = res + " public static HashMap collect_" + oldindex + "(Map _l"; for (int i = 0; i < pars.size(); i++) { Attribute par = (Attribute) pars.get(i); res = res + "," + par.getType().getJava() + " " + par.getName(); @@ -781,7 +781,7 @@ else if (ename.equals("boolean")) res = res + ")\n"; res = res + " { // implements: " + left + "->collect( " + var + " | " + exp + " )\n" + - " Map _results_" + oldindex + " = new HashMap();\n" + + " HashMap _results_" + oldindex + " = new HashMap();\n" + " java.util.Set _keys = _l.keySet();\n" + " for (Object _i : _keys)\n" + " { " + tname + " " + var + " = (" + tname + ") _l.get(_i);\n" + @@ -1419,7 +1419,7 @@ else if (ename.equals("boolean")) /* Version for maps: */ - res = res + " public static Map reject_" + oldindex + "(Map _l"; + res = res + " public static HashMap reject_" + oldindex + "(Map _l"; for (int i = 0; i < pars.size(); i++) { Attribute par = (Attribute) pars.get(i); @@ -1429,7 +1429,7 @@ else if (ename.equals("boolean")) } res = res + ")\n"; res = res + " { // Implements: " + left + "->reject(" + var + " | " + pred + ")\n" + - " Map _results_" + oldindex + " = new java.util.HashMap();\n" + + " HashMap _results_" + oldindex + " = new HashMap();\n" + " java.util.Set _keys = _l.keySet();\n" + " for (Object _i : _keys)\n"; if (ename.equals("int") || "Integer".equals(tname)) diff --git a/BinaryExpression.java b/BinaryExpression.java index 4ec0234c..26c86c2f 100644 --- a/BinaryExpression.java +++ b/BinaryExpression.java @@ -3877,8 +3877,35 @@ else if ("|sortedBy".equals(operator)) elementType = scope.elementType; multiplicity = ModelElement.MANY; } - else if (operator.equals("->collect") || - operator.equals("->unionAll") || + else if (operator.equals("->collect")) + { if (left.isCollection() || left.isMap()) { } + else + { System.err.println("!! Left argument of " + operator + + " must be a collection or map"); + left.setType(new Type("Sequence", null)); + + if (left instanceof BasicExpression) + { String vname = ((BasicExpression) left).basicString(); + vartypes.put(vname, left.getType()); + } + } + + Type tl = (Type) vartypes.get(left + ""); + + if (left.isMap() || Type.isMapType(tl)) + { elementType = right.getType(); + type = new Type("Map", null); + type.setElementType(elementType); + if (tl == null) { tl = tleft; } + type.setKeyType(tleft.getKeyType()); + } + else + { elementType = right.getType(); + type = new Type("Sequence", null); + type.setElementType(elementType); + } // NOT sorted + } + else if (operator.equals("->unionAll") || operator.equals("->intersectAll") || operator.equals("->concatenateAll") || operator.equals("->any")) @@ -3910,8 +3937,78 @@ else if (operator.equals("->collect") || type = left.elementType; } } - else if (operator.equals("|C") || - "|unionAll".equals(operator) || + else if (operator.equals("|C")) + { BinaryExpression lexp = (BinaryExpression) left; + Expression scope = lexp.right; + Expression vbl = lexp.left; + + Type vblType = (Type) vartypes.get(vbl + ""); + + Type scopetype = (Type) vartypes.get(scope + ""); + + if (scope.isCollection() || scope.isMap()) + { if (Type.isVacuousType(scope.elementType)) + { System.err.println("!! No element type for " + scope); + Type tt = (Type) vartypes.get(scope + ""); + if (tt != null && + !Type.isVacuousType(tt.elementType)) + { scope.setElementType(tt.elementType); + vbl.setType(tt.elementType); + } + else if (vblType != null) + { scope.setElementType(vblType); + vbl.setType(vblType); + } + + System.out.println(">> Set " + scope + " element type to " + scope.getElementType()); + } + } + else if (Type.isMapType(scopetype) || + Type.isCollectionType(scopetype)) + { if (Type.isVacuousType(scope.elementType)) + { System.err.println("!! No element type for " + scope); + if (scopetype != null && + !Type.isVacuousType(scopetype.elementType)) + { scope.setElementType(scopetype.elementType); + vbl.setType(scopetype.elementType); + } + else if (vblType != null) + { scope.setElementType(vblType); + vbl.setType(vblType); + } + + scope.setType(scopetype); + + System.out.println(">> Set " + this + " element type to " + scope.getElementType()); + } + } + else + { System.err.println("!! Left argument of " + operator + + " must be a collection"); + scope.setType(new Type("Sequence", null)); + + if (scope instanceof BasicExpression) + { String vname = ((BasicExpression) scope).basicString(); + vartypes.put(vname, scope.getType()); + } + } + + if (scope.isMap() || Type.isMapType(scopetype)) + { elementType = right.getType(); + type = new Type("Map", null); + type.setElementType(elementType); + if (scopetype == null) { scopetype = scope.getType(); } + type.setKeyType(scopetype.getKeyType()); + + // JOptionPane.showInputDialog(">> Type of " + this + " is " + type); + } + else + { elementType = right.getType(); + type = new Type("Sequence", null); + type.setElementType(elementType); + } // NOT sorted + } + else if ("|unionAll".equals(operator) || "|intersectAll".equals(operator) || "|concatenateAll".equals(operator)) { BinaryExpression lexp = (BinaryExpression) left; @@ -4314,7 +4411,7 @@ else if ("->restrict".equals(operator) || { if (tleft != null) { type = tleft; elementType = left.elementType; - } + } // sorted if left is sorted else { type = new Type("Map", null); } @@ -4556,7 +4653,45 @@ else if (operator.equals("\\/") || Type lftype = left.getType(); Type rtype = right.getType(); - if (left.isMap() && !right.isMap()) + Type letype = left.getElementType(); + Type retype = right.getElementType(); + + if (left.isMap() && right.isMap()) + { if (operator.equals("\\/") || + operator.equals("->union") || + operator.equals("->symmetricDifference") || + operator.equals("->intersection") || + operator.equals("/\\")) + { Vector etypes = new Vector(); + etypes.add(left); + etypes.add(right); + elementType = Type.determineElementType(etypes); + type = new Type("Map", null); + type.setSorted(lftype.isSorted()); + type.setKeyType(lftype.getKeyType()); + type.setElementType(elementType); + } + else + { type = new Type("boolean", null); } + } + else if (left.isCollection() && right.isCollection()) + { if (operator.equals("\\/") || + operator.equals("->union") || + operator.equals("->symmetricDifference") || + operator.equals("->intersection") || + operator.equals("/\\")) + { Vector etypes = new Vector(); + etypes.add(left); + etypes.add(right); + elementType = Type.determineElementType(etypes); + type = new Type(lftype.getName(), null); + type.setSorted(lftype.isSorted()); + type.setElementType(elementType); + } + else + { type = new Type("boolean", null); } + } + else if (left.isMap() && !right.isMap()) { System.err.println("!! RHS of " + this + " must be map"); @@ -4566,10 +4701,10 @@ else if (operator.equals("\\/") || right.setType(rtype); if (operator.equals("\\/") || - operator.equals("->union") || - operator.equals("->symmetricDifference") || - operator.equals("->intersection") || - operator.equals("/\\")) + operator.equals("->union") || + operator.equals("->symmetricDifference") || + operator.equals("->intersection") || + operator.equals("/\\")) { type = new Type("Map", null); elementType = left.elementType; type.setKeyType(lftype.getKeyType()); @@ -4658,9 +4793,10 @@ else if (!left.isCollection() && right.isCollection()) vartypes.put(vname, left.getType()); } } - else + else if (!left.isCollection() && + !left.isMap()) { System.err.println("!! Arguments of " + this + - " must be collections"); + " must be collections/maps"); Type ltype = left.getType(); if (ltype.getAlias() != null && @@ -5586,8 +5722,8 @@ else if (operator.equals("->includes") || { } else { System.err.println("!! TYPE ERROR: LHS of " + this + " must be a collection"); - JOptionPane.showMessageDialog(null, "LHS of " + this + " must be a collection!", - "Type error", JOptionPane.ERROR_MESSAGE); + // JOptionPane.showMessageDialog(null, "LHS of " + this + " must be a collection!", + // "Type error", JOptionPane.ERROR_MESSAGE); } // deduce type of one side from that of other if (tright == null && tleft != null) @@ -6081,23 +6217,26 @@ private void tcCollect(Type tleft, Type tright, Entity eright) restype.keyType = tleft.keyType; restype.elementType = tright; type = restype; - elementType = tright; + elementType = tright; + + // JOptionPane.showInputDialog("Type of " + this + + // " is " + type); return; } else if (collectleft.isMultiple()) { } else { System.err.println("!!! TYPE ERROR: LHS of collect must be a collection! " + this); - JOptionPane.showMessageDialog(null, "LHS must be a collection: " + this, - "Type error", JOptionPane.ERROR_MESSAGE); + // JOptionPane.showMessageDialog(null, "LHS must be a collection: " + this, + // "Type error", JOptionPane.ERROR_MESSAGE); // type = null; // return; } if (tright == null) { System.err.println("!!! TYPE ERROR: No type for collect RHS: " + this); - JOptionPane.showMessageDialog(null, "ERROR: No type for collect RHS: " + this, - "Type error", JOptionPane.ERROR_MESSAGE); + // JOptionPane.showMessageDialog(null, "ERROR: No type for collect RHS: " + this, + // "Type error", JOptionPane.ERROR_MESSAGE); return; } @@ -6155,8 +6294,8 @@ private void tcMathOps(Type tleft, Type tright, System.err.println("! Warning!: arguments must be numeric in: " + this + " Deduced type: " + type); if (type == null) - { JOptionPane.showMessageDialog(null, "Arguments not numeric in: " + this, - "Type error", JOptionPane.ERROR_MESSAGE); + { // JOptionPane.showMessageDialog(null, "Arguments not numeric in: " + this, + // "Type error", JOptionPane.ERROR_MESSAGE); } } @@ -20088,6 +20227,15 @@ public Expression simplifyOCL() Expression lexpr = left.simplifyOCL(); Expression rexpr = right.simplifyOCL(); + if ("->union".equals(operator) && + lexpr instanceof SetExpression && + rexpr instanceof SetExpression) + { // merge the literal sets/sequences/maps + SetExpression res = + SetExpression.mergeSetExpressions((SetExpression) lexpr, + (SetExpression) rexpr); + return res; + } if (operator.equals("|")) { BinaryExpression arg = (BinaryExpression) left; diff --git a/Compiler2.java b/Compiler2.java index b274e9c3..f4344425 100644 --- a/Compiler2.java +++ b/Compiler2.java @@ -11116,11 +11116,23 @@ public static void main(String[] args) // c.nospacelexicalanalysis("(a[i][j]).f(1)"); // c.nospacelexicalanalysis("(!a).f(1)"); - c.nospacelexicalanalysis("(OclFile[\"SYSOUT\"]).println(x)"); + // c.nospacelexicalanalysis("(OclFile[\"SYSOUT\"]).println(x)"); + + c.nospacelexicalanalysis("Map{ \"Name\" |-> Sequence{\"Braund, Mr. Owen Harris\"}->union(Sequence{\"Allen, Mr. William Henry\"}->union(Sequence{ \"Bonnell, Miss. Elizabeth\" })) }->union(Map{ \"Age\" |-> Sequence{22}->union(Sequence{35}->union(Sequence{ 58 })) }->union(Map{ \"Sex\" |-> Sequence{\"male\"}->union(Sequence{\"male\"}->union(Sequence{ \"female\" })) }->union(Map{ \"Fare\" |-> Sequence{102.0}->union(Sequence{99.0}->union(Sequence{ 250.0 })) }) ) )"); + Expression zz = c.parseExpression(); System.out.println(zz); + zz.typeCheck(new Vector(), new Vector(), new Vector(), new Vector()); + + Expression pp = zz.simplifyOCL(); + + System.out.println(pp); + + pp.typeCheck(new Vector(), new Vector(), new Vector(), new Vector()); + + System.out.println(">>> " + pp.getType()); // Compiler2 ccx = new Compiler2(); // ccx.nospacelexicalanalysis("x : int"); diff --git a/SetExpression.java b/SetExpression.java index 47e8ac6e..47a67396 100644 --- a/SetExpression.java +++ b/SetExpression.java @@ -25,7 +25,7 @@ public SetExpression(boolean b) { type = new Type("Sequence", null); } else { type = new Type("Set", null); } - } + } // what about maps? public SetExpression(Vector v) { if (v == null || v.size() == 0 || @@ -72,6 +72,80 @@ public static SetExpression newRefSetExpression(Expression elem) return res; } + public SetExpression(Vector elems, Type typ) + { elements = elems; + type = (Type) typ.clone(); + if (Type.isSequenceType(type)) + { ordered = true; } + + elementType = Type.determineType(elements); + type.setElementType(elementType); + } + + public static SetExpression mergeSetExpressions( + SetExpression left, + SetExpression right) + { // ->union of two literal collections, maps + + Type typ = left.getType(); + Vector elems1 = left.getElements(); + Vector elems2 = right.getElements(); + Vector newelems = new Vector(); + newelems.addAll(elems1); + + if (Type.isSequenceType(typ)) + { newelems.addAll(elems2); + SetExpression res = new SetExpression(newelems,typ); + return res; + } + + if (Type.isSetType(typ)) + { for (int i = 0; i < elems2.size(); i++) + { Expression e2 = (Expression) elems2.get(i); + if (VectorUtil.containsEqualString( + e2 + "", newelems)) + { } + else + { newelems.add(e2); } + } + SetExpression res = new SetExpression(newelems,typ); + return res; + } + + // Else - maps + + System.out.println("*** Merging maps " + left + " and " + right); + + Vector mapelems = new Vector(); + for (int i = 0; i < elems1.size(); i++) + { BinaryExpression maplet1 = + (BinaryExpression) elems1.get(i); + Expression key1 = maplet1.getLeft(); + + System.out.println("*** KEY 1: " + key1); + + boolean foundkey1 = false; + for (int j = 0; j < elems2.size(); j++) + { BinaryExpression maplet2 = + (BinaryExpression) elems2.get(j); + Expression key2 = maplet2.getLeft(); + + if ((key1 + "").equals(key2 + "")) + { // maplet2 overrides maplet1 + foundkey1 = true; + break; + } // don't include maplet1 in mapelems + } + + if (!foundkey1) + { mapelems.add(maplet1); } + } + mapelems.addAll(elems2); + + SetExpression res = new SetExpression(mapelems,typ); + return res; + } + public Vector getParameters() { return new Vector(); } @@ -932,8 +1006,9 @@ public boolean typeInference(final Vector typs, java.util.Map vartypes) { return typeCheck(typs,ents,contexts,env); } - public boolean typeCheck(final Vector types, final Vector entities, - final Vector contexts, final Vector env) + public boolean typeCheck(final Vector types, + final Vector entities, + final Vector contexts, final Vector env) { boolean res = true; if (type != null && "Ref".equals(type.getName())) diff --git a/Statement.java b/Statement.java index 11296da6..f1371ef6 100644 --- a/Statement.java +++ b/Statement.java @@ -7358,6 +7358,9 @@ else if (typ != null) if (initialExpression != null) { initialExpression.typeInference(types,entities, ctxs,env,vartypes); + + // JOptionPane.showInputDialog("--- " + vartypes + " -- Type inference for: " + initialExpression.getType() + " for " + initialExpression); + System.out.println(">>> Inferred type " + initialExpression.getType() + "(" + initialExpression.getElementType() + @@ -12781,10 +12784,12 @@ else if (BasicExpression.isMapAccess(lhs)) public boolean typeInference(Vector types, Vector entities, Vector cs, Vector env, java.util.Map vartypes) { // Also recognise the type as an entity or enumeration if it exists - boolean res = lhs.typeCheck(types,entities,cs,env); - res = rhs.typeCheck(types,entities,cs,env); + // boolean res = lhs.typeCheck(types,entities,cs,env); + // res = rhs.typeCheck(types,entities,cs,env); + boolean res = rhs.typeInference(types,entities,cs,env,vartypes); Type rhsType = rhs.getType(); - res = rhs.typeInference(types,entities,cs,env,vartypes); + + vartypes.put(lhs + "", rhsType); if (Type.isVacuousType(lhs.type) && !Type.isVacuousType(rhsType)) @@ -12816,6 +12821,8 @@ else if (BasicExpression.isMapAccess(lhs)) System.out.println(">>> " + lhs + " actual type is " + declaredType); } // does not allow for changing the type + // JOptionPane.showInputDialog("--- deduced type " + rhsType + " for " + lhs); + return res; } diff --git a/Type.java b/Type.java index 643df75e..2b4446b8 100644 --- a/Type.java +++ b/Type.java @@ -166,6 +166,12 @@ public void setGenericTypeParameters(Vector pars) { entity.setTypeParameters(pars); } } + public boolean equals(Object other) + { if (other instanceof Type) + { return ("" + this).equals(other + ""); } + return false; + } + public static boolean isOclLibraryType(String tname) { if (tname == null) { return false; } @@ -4832,6 +4838,34 @@ public static Type getTypeFor(String typ, Vector types, Vector entities) } } + if (typ.startsWith("SortedMap(String,") && typ.endsWith(")")) + { String nt = typ.substring(17,typ.length()-1); + Type innerT = getTypeFor(nt, types, entities); + Type resT = new Type("Map",null); + resT.setSorted(true); + resT.setKeyType(new Type("String", null)); + resT.setElementType(innerT); + return resT; + } + + if (typ.startsWith("SortedMap(") && typ.endsWith(")")) + { for (int i = 11; i < typ.length(); i++) + { if (",".equals(typ.charAt(i) + "")) + { String nt = typ.substring(11,i); + Type innerT = getTypeFor(nt, types, entities); + String rt = typ.substring(i+1,typ.length()-1); + Type restT = getTypeFor(rt, types, entities); + if (innerT != null && restT != null) + { Type resT = new Type("Map",null); + resT.setSorted(true); + resT.setKeyType(innerT); + resT.setElementType(restT); + return resT; + } + } + } + } + if (typ.startsWith("Function(String,") && typ.endsWith(")")) { String nt = typ.substring(16,typ.length()-1); Type innerT = getTypeFor(nt, types, entities); @@ -4888,6 +4922,12 @@ else if (tn2.equals("double") && { expectedType = t; } else if (tn2.equals("long") && tn1.equals("int")) { expectedType = t; } + else if (Type.isSequenceType(t) && + Type.isSequenceType(expectedType)) + { expectedType = new Type("Sequence", null); } + else if (Type.isSetType(t) && + Type.isSetType(expectedType)) + { expectedType = new Type("Set", null); } else if (tn1.equals(tn2)) { } // both maps, both sequences, both sets else @@ -4933,6 +4973,12 @@ else if (tn2.equals("double") && (tn1.equals("int") || tn1.equals("long"))) { expectedType = t; } else if (tn2.equals("long") && tn1.equals("int")) { expectedType = t; } + else if (Type.isSequenceType(t) && + Type.isSequenceType(expectedType)) + { expectedType = new Type("Sequence", null); } + else if (Type.isSetType(t) && + Type.isSetType(expectedType)) + { expectedType = new Type("Set", null); } else if (tn1.equals(tn2)) { } else { Entity e1 = expectedType.getEntity(); @@ -5014,14 +5060,25 @@ else if (expectedType.equals(t)) { } else { String tn1 = expectedType.getName(); String tn2 = t.getName(); - if (tn1.equals("double") && (tn2.equals("int") || tn2.equals("long"))) + + if (tn1.equals("double") && + (tn2.equals("int") || tn2.equals("long"))) { } else if (tn1.equals("long") && tn2.equals("int")) { } - else if (tn2.equals("double") && (tn1.equals("int") || tn1.equals("long"))) + else if (tn2.equals("double") && + (tn1.equals("int") || tn1.equals("long"))) { expectedType = t; } else if (tn2.equals("long") && tn1.equals("int")) { expectedType = t; } + else if (Type.isSequenceType(t) && + Type.isSequenceType(expectedType)) + { expectedType = new Type("Sequence", null); + JOptionPane.showInputDialog("** Deduced element type " + expectedType + " for " + elems); + } + else if (Type.isSetType(t) && + Type.isSetType(expectedType)) + { expectedType = new Type("Set", null); } else if (tn1.equals(tn2)) { } else if (expectedType.isEnumeration() && t.isEnumeration()) @@ -5047,6 +5104,7 @@ else if (expectedType.isEnumeration() && } } } + return expectedType; }