diff --git a/build.sbt b/build.sbt index 2db9f58..4635670 100644 --- a/build.sbt +++ b/build.sbt @@ -3,7 +3,7 @@ organization := "io.shiftleft" scalaVersion := "2.13.1" enablePlugins(GitVersioning) -val cpgVersion = "0.11.338" +val cpgVersion = "0.11.334" val antlrVersion = "4.7.2" libraryDependencies ++= Seq( diff --git a/fuzzyc2cpg.sh b/fuzzyc2cpg.sh index f5573a2..de87277 100755 --- a/fuzzyc2cpg.sh +++ b/fuzzyc2cpg.sh @@ -3,4 +3,4 @@ SCRIPT_ABS_PATH=$(readlink -f "$0") SCRIPT_ABS_DIR=$(dirname $SCRIPT_ABS_PATH) -$SCRIPT_ABS_DIR/target/universal/stage/bin/fuzzyc2cpg -J-XX:+UseG1GC -J-XX:CompressedClassSpaceSize=128m -J-XX:+UseStringDeduplication -Dlogback.configurationFile=$SCRIPT_ABS_DIR/config/logback.xml $@ +$SCRIPT_ABS_DIR/target/universal/stage/bin/fuzzyc2cpg -Dlogback.configurationFile=$SCRIPT_ABS_DIR/config/logback.xml $@ diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/ast/statements/jump/GotoStatement.java b/src/main/java/io/shiftleft/fuzzyc2cpg/ast/statements/jump/GotoStatement.java index 9315242..c3cce05 100644 --- a/src/main/java/io/shiftleft/fuzzyc2cpg/ast/statements/jump/GotoStatement.java +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/ast/statements/jump/GotoStatement.java @@ -12,10 +12,6 @@ public String getTargetName() { return getChild(0).getEscapedCodeStr(); } - public String getEscapedCodeStr() { - return "goto " + getTargetName() + ";"; - } - public void accept(ASTNodeVisitor visitor) { visitor.visit(this); } diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/output/inmemory/OutputModule.java b/src/main/java/io/shiftleft/fuzzyc2cpg/output/inmemory/OutputModule.java new file mode 100644 index 0000000..1435a3f --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/output/inmemory/OutputModule.java @@ -0,0 +1,47 @@ +package io.shiftleft.fuzzyc2cpg.output.inmemory; + +import io.shiftleft.codepropertygraph.Cpg; +import io.shiftleft.codepropertygraph.cpgloading.ProtoCpgLoader; +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModule; +import 
overflowdb.OdbConfig; +import io.shiftleft.proto.cpg.Cpg.CpgStruct; + +import java.util.LinkedList; +import java.util.List; + +public class OutputModule implements CpgOutputModule { + + private LinkedList cpgBuilders; + private Cpg cpg; + + protected OutputModule() { + this.cpgBuilders = new LinkedList<>(); + } + + public Cpg getInternalGraph() { + return cpg; + } + + @Override + public void setOutputIdentifier(String identifier) { } + + @Override + public void persistCpg(CpgStruct.Builder cpg) { + synchronized (cpgBuilders) { + cpgBuilders.add(cpg); + } + } + + public void persist() { + CpgStruct.Builder mergedBuilder = CpgStruct.newBuilder(); + + cpgBuilders.forEach(builder -> { + mergedBuilder.mergeFrom(builder.build()); + }); + + List list = new LinkedList<>(); + list.add(mergedBuilder.build()); + cpg = ProtoCpgLoader.loadFromListOfProtos(list, OdbConfig.withoutOverflow()); + } + +} diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/output/inmemory/OutputModuleFactory.java b/src/main/java/io/shiftleft/fuzzyc2cpg/output/inmemory/OutputModuleFactory.java new file mode 100644 index 0000000..227f23c --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/output/inmemory/OutputModuleFactory.java @@ -0,0 +1,37 @@ +package io.shiftleft.fuzzyc2cpg.output.inmemory; + +import io.shiftleft.codepropertygraph.Cpg; +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModule; +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModuleFactory; + +public class OutputModuleFactory implements CpgOutputModuleFactory { + + private OutputModule outputModule; + + @Override + public CpgOutputModule create() { + synchronized (this) { + if (outputModule == null) { + outputModule = new OutputModule(); + } + } + return outputModule; + } + + /** + * An internal representation of the graph. 
+ * + * @return the internally constructed graph + */ + public Cpg getInternalGraph() { + return outputModule.getInternalGraph(); + } + + @Override + public void persist() { + if (outputModule != null) { + outputModule.persist(); + } + } +} + diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/OutputModule.java b/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/OutputModule.java new file mode 100644 index 0000000..0a693e7 --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/OutputModule.java @@ -0,0 +1,98 @@ +package io.shiftleft.fuzzyc2cpg.output.protobuf; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hasher; +import com.google.common.hash.Hashing; +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModule; +import io.shiftleft.proto.cpg.Cpg.CpgStruct; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.concurrent.ThreadLocalRandom; + +public class OutputModule implements CpgOutputModule { + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final Path protoTempDir; + private final boolean writeToDisk; + + private String outputIdentifier; + + public OutputModule(boolean writeToDisk, + Path protoTempDir) { + this.writeToDisk = writeToDisk; + this.protoTempDir = protoTempDir; + } + + @Override + public void setOutputIdentifier(String identifier) { + outputIdentifier = identifier; + } + + /** + * This is called for each code property graph. There is one + * code property graph per method, and one graph for the overall + * program structure. 
+ * */ + + @Override + public void persistCpg(CpgStruct.Builder cpg) throws IOException { + CpgStruct buildCpg = cpg.build(); + if (writeToDisk) { + String outputFilename = getOutputFileName(); + try (FileOutputStream outStream = new FileOutputStream(outputFilename)) { + buildCpg.writeTo(outStream); + } + } + } + + /** + * The complete handling for an already existing file should not be necessary. + * This was added as a last resort to not get incomplete cpgs. + * In case we have a hash collision, the resulting cpg part file names will not + * be identical over different java2cpg runs. + */ + private String getOutputFileName() { + String outputFilename = null; + int postfix = 0; + boolean fileExists = true; + int resolveAttemptCounter = 0; + + while (fileExists && resolveAttemptCounter < 10) { + outputFilename = generateOutputFilename(postfix); + if (Files.exists(Paths.get(outputFilename))) { + postfix = ThreadLocalRandom.current().nextInt(0, 100000); + + logger.warn("Hash collision identifier={}, postfix={}." + + " Retry with random postfix.", outputIdentifier, postfix); + + resolveAttemptCounter++; + } else { + fileExists = false; + } + } + + if (fileExists) { + logger.error("Unable to resolve hash collision. 
Cpg will be incomplete"); + } + + return outputFilename; + } + + private String generateOutputFilename(int postfix) { + HashFunction hashFunction = Hashing.murmur3_128(); + + Hasher hasher = hashFunction.newHasher(); + hasher.putUnencodedChars(outputIdentifier); + hasher.putInt(postfix); + + String protoSuffix = ".bin"; + return protoTempDir.toString() + File.separator + hasher.hash() + protoSuffix; + } +} diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/OutputModuleFactory.java b/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/OutputModuleFactory.java new file mode 100644 index 0000000..b401395 --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/OutputModuleFactory.java @@ -0,0 +1,62 @@ +package io.shiftleft.fuzzyc2cpg.output.protobuf; + +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModule; +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModuleFactory; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.io.FileUtils; + +public class OutputModuleFactory implements CpgOutputModuleFactory { + + private final List outputModules = new ArrayList<>(); + private final boolean writeToDisk; + private final Path protoTempDir; + private final String outputFilename; + + public OutputModuleFactory(String outputFilename, + boolean writeToDisk) throws IOException { + this.writeToDisk = writeToDisk; + this.protoTempDir = Files.createTempDirectory("proto"); + this.outputFilename = outputFilename; + } + + @Override + public CpgOutputModule create() { + OutputModule outputModule = new OutputModule(writeToDisk, protoTempDir); + synchronized (this) { + outputModules.add(outputModule); + } + return outputModule; + } + + /** + * Store collected CPGs into the output directory specified + * for this output module. + * Note: This method should be called only once all intermediate CPGs + * have been processed and collected. 
+ * If the output module was configured to combine intermediate CPGs into a single + * one, we will combine individual proto files. + * */ + @Override + public void persist() throws IOException { + if (writeToDisk) { + try { + ThreadedZipper threadedZipper = new ThreadedZipper(protoTempDir, outputFilename); + threadedZipper.start(); + // wait until the thread is finished + // if we don't wait, the output folder + // may be deleted and and we get null pointer + threadedZipper.join(); + } catch (InterruptedException interruptedException) { + throw new IOException(interruptedException); + } + } + if (this.protoTempDir != null && Files.exists(this.protoTempDir)) { + FileUtils.deleteDirectory(this.protoTempDir.toFile()); + } + } +} diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/ThreadedZipper.java b/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/ThreadedZipper.java new file mode 100644 index 0000000..8d568c7 --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/output/protobuf/ThreadedZipper.java @@ -0,0 +1,158 @@ +package io.shiftleft.fuzzyc2cpg.output.protobuf; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.FileSystem; +import java.nio.file.FileSystems; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +class ThreadedZipper extends Thread { + private Logger logger = LoggerFactory.getLogger(getClass()); + private Path protoDir; + private String outputFile; + private static long TIMEOUT = 
Long.MAX_VALUE; + + ThreadedZipper(Path protoDir, String outputFile) { + this.protoDir = protoDir; + this.outputFile = outputFile; + } + + private void doCopy(ZipEntry entry) { + try { + logger.debug("copying from " + entry.from + " to " + entry.to); + Files.copy(entry.from, entry.to, StandardCopyOption.REPLACE_EXISTING); + } catch (IOException exception) { + logger.error("Failed to copy files in the zipper", exception); + throw new RuntimeException(exception); + } + } + + @Override + public void run() { + ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + try { + // create zip file system + Map env = new HashMap<>(); + env.put("create", "true"); + + // handle the -n parameter + Path path = Paths.get(this.outputFile); + URI zipUri; + try { + zipUri = URI.create("jar:" + path.toUri().toString()); + } catch (Exception exception) { + pool.shutdownNow(); + logger.error( + "Failed to create URI using path " + path.toAbsolutePath().toString(), exception); + throw new RuntimeException(exception); + } + if (Files.exists(path)) { + Files.delete(path); + } + logger.debug("Writing file to: " + path); + try (FileSystem zipFileSystem = FileSystems.newFileSystem(zipUri, env)) { + + // if the input is a file + File inputFile = protoDir.toFile(); + if (inputFile.isFile()) { + // TODO: should crash, we expect a directory + logger.debug("Zipping " + inputFile.getName() + " file"); + doCopy(new ZipEntry(inputFile, zipFileSystem)); + } else { + // loop over sorted files + File[] files = inputFile.listFiles(); + if (files == null) { + logger.error("Couldn't list files in " + inputFile); + return; + } + Arrays.sort(files); + + List entries = Arrays.stream(files).flatMap(f -> { + if (f.isFile()) { + return Stream.of(new ZipEntry(f, zipFileSystem)); + } else { + return Stream.empty(); + } + }).collect(Collectors.toList()); + + // create the few special entries in sorted order, while the thread pool will create + // the rest randomly + 
Iterator iterator = entries.iterator(); + while (iterator.hasNext()) { + ZipEntry entry = iterator.next(); + + if (!entry.getTo().getFileName().toString().startsWith("$")) { + break; + } + + doCopy(entry); + + iterator.remove(); + } + // force create ZIP entries in sorted order before copying in thread pool + entries.forEach(entry -> + pool.submit(new Runnable() { + public void run() { + doCopy(entry); + } + }) + ); + + // NOTE: Abandon hope all ye who enter here. + // Before you even start asking yourself if that's really the right order... + // Yes, that's the right way to shutdown and await for the pool + // of executors to finish. Otherwise, you might still be having some race-conditions. + // That's what the book says. Move on. + pool.shutdown(); + if (!pool.awaitTermination(TIMEOUT, TimeUnit.SECONDS)) { + logger.error("Failed to finish tasks in a timely manner. Cleaning up..."); + pool.shutdownNow(); + if (!pool.awaitTermination(TIMEOUT, TimeUnit.SECONDS)) { + throw new RuntimeException("Executor pool didn't terminate on time"); + } + } + } + } + } catch (IOException | InterruptedException exception) { + pool.shutdownNow(); + logger.error("Failed to create the zip file", exception); + throw new RuntimeException(exception); + } + } + + private class ZipEntry { + private Path from; + private Path to; + + public ZipEntry(File file, FileSystem fs) { + this.from = Paths.get(file.getAbsolutePath()); + this.to = fs.getPath(from.getFileName().toString()); + } + + public Path getFrom() { + return from; + } + + public Path getTo() { + return to; + } + } +} diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/AntlrParserDriver.java b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/AntlrParserDriver.java new file mode 100644 index 0000000..512b81c --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/AntlrParserDriver.java @@ -0,0 +1,305 @@ +package io.shiftleft.fuzzyc2cpg.parser; + +import io.shiftleft.fuzzyc2cpg.Utils; +import 
io.shiftleft.fuzzyc2cpg.ast.AstNode; +import io.shiftleft.fuzzyc2cpg.ast.AstNodeBuilder; +import io.shiftleft.fuzzyc2cpg.ast.logical.statements.CompoundStatement; +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModule; +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModuleFactory; +import io.shiftleft.passes.KeyPool; +import io.shiftleft.proto.cpg.Cpg.CpgStruct; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Stack; +import java.util.function.Consumer; + +import io.shiftleft.proto.cpg.Cpg; +import jdk.nashorn.internal.runtime.ParserException; +import org.antlr.v4.runtime.*; +import org.antlr.v4.runtime.misc.ParseCancellationException; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.ParseTreeListener; +import org.antlr.v4.runtime.tree.ParseTreeWalker; + +import static org.antlr.v4.runtime.Token.EOF; + +abstract public class AntlrParserDriver { + // TODO: This class does two things: + // * It is a driver for the ANTLRParser, i.e., the parser + // that creates ParseTrees from Strings. It can also already + // 'walk' the ParseTree to create ASTs. + // * It is an AST provider in that it will notify watchers + // when ASTs are ready. + // We should split this into two classes. 
+ + public Stack> builderStack = new Stack<>(); + public TokenSubStream stream; + public String filename; + + private Parser antlrParser; + private ParseTreeListener listener; + private CommonParserContext context = null; + + private List observers = new ArrayList<>(); + private Cpg.CpgStruct.Builder cpg = CpgStruct.newBuilder(); + private Cpg.CpgStruct.Node fileNode; + private KeyPool keyPool; + private CpgOutputModuleFactory outputModuleFactory; + + public AntlrParserDriver() { + super(); + } + + public void setOutputModuleFactory(CpgOutputModuleFactory factory) { + this.outputModuleFactory = factory; + } + + public void setKeyPool(KeyPool keyPool) { + this.keyPool = keyPool; + } + + public void setFileNode(Cpg.CpgStruct.Node fileNode) { + this.fileNode = fileNode; + } + + public abstract ParseTree parseTokenStreamImpl(TokenSubStream tokens); + + public abstract Lexer createLexer(CharStream input); + + public void parseAndWalkFile(String filename) throws ParserException, IOException { + handleHiddenTokens(filename); + TokenSubStream stream = createTokenStreamFromFile(filename); + initializeContextWithFile(filename, stream); + + ParseTree tree = parseTokenStream(stream); + walkTree(tree); + + CpgOutputModule outputModule = outputModuleFactory.create(); + outputModule.setOutputIdentifier( + filename + " driver" + ); + outputModule.persistCpg(cpg); + } + + private void handleHiddenTokens(String filename) { + CommonTokenStream tokenStream = createStreamOfHiddenTokensFromFile(filename); + TokenSource tokenSource = tokenStream.getTokenSource(); + + while (true){ + Token token = tokenSource.nextToken(); + if (token.getType() == EOF) { + break; + } + if (token.getChannel() != Token.HIDDEN_CHANNEL) { + continue; + } + int line = token.getLine(); + String text = token.getText(); + // We can add to `CPG` here + + Cpg.CpgStruct.Node commentNode = Utils.newNode(Cpg.CpgStruct.Node.NodeType.COMMENT) + .setKey(keyPool.next()) + 
.addProperty(Cpg.CpgStruct.Node.Property.newBuilder() + .setName(Cpg.NodePropertyName.LINE_NUMBER) + .setValue(Cpg.PropertyValue.newBuilder().setIntValue(line))) + .addProperty(Cpg.CpgStruct.Node.Property.newBuilder() + .setName(Cpg.NodePropertyName.CODE) + .setValue(Cpg.PropertyValue.newBuilder().setStringValue(text)) + ) + .build(); + + cpg.addNode(commentNode); + + cpg.addEdge(Cpg.CpgStruct.Edge.newBuilder() + .setType(Cpg.CpgStruct.Edge.EdgeType.AST) + .setSrc(fileNode.getKey()) + .setDst(commentNode.getKey()) + ); + + } + } + + public void parseAndWalkTokenStream(TokenSubStream tokens) + throws ParserException { + filename = ""; + stream = tokens; + ParseTree tree = parseTokenStream(tokens); + walkTree(tree); + } + + public ParseTree parseAndWalkString(String input) throws ParserException { + ParseTree tree = parseString(input); + walkTree(tree); + return tree; + } + + public ParseTree parseTokenStream(TokenSubStream tokens) + throws ParserException { + ParseTree returnTree = parseTokenStreamImpl(tokens); + if (returnTree == null) { + throw new ParserException(""); + } + return returnTree; + } + + public ParseTree parseString(String input) throws ParserException { + CharStream inputStream = CharStreams.fromString(input); + Lexer lex = createLexer(inputStream); + TokenSubStream tokens = new TokenSubStream(lex); + ParseTree tree = parseTokenStream(tokens); + return tree; + } + + protected TokenSubStream createTokenStreamFromFile(String filename) + throws ParserException { + + CharStream input = createInputStreamForFile(filename); + Lexer lexer = createLexer(input); + TokenSubStream tokens = new TokenSubStream(lexer); + return tokens; + + } + + private CharStream createInputStreamForFile(String filename) { + + try { + return CharStreams.fromFileName(filename); + } catch (IOException exception) { + throw new RuntimeException(String.format("Unable to find source file [%s]", filename)); + } + + } + + protected CommonTokenStream 
createStreamOfHiddenTokensFromFile(String filename) { + CharStream input = createInputStreamForFile(filename); + Lexer lexer = createLexer(input); + return new CommonTokenStream(lexer, Token.HIDDEN_CHANNEL); + } + + protected void walkTree(ParseTree tree) { + ParseTreeWalker walker = new ParseTreeWalker(); + walker.walk(getListener(), tree); + } + + protected void initializeContextWithFile(String filename, + TokenSubStream stream) { + setContext(new CommonParserContext()); + getContext().filename = filename; + getContext().stream = stream; + initializeContext(getContext()); + } + + protected boolean isRecognitionException(RuntimeException ex) { + + return ex.getClass() == ParseCancellationException.class + && ex.getCause() instanceof RecognitionException; + } + + protected void setLLStarMode(Parser parser) { + parser.removeErrorListeners(); + // parser.addErrorListener(ConsoleErrorListener.INSTANCE); + parser.setErrorHandler(new DefaultErrorStrategy()); + // parser.getInterpreter().setPredictionMode(PredictionMode.LL); + } + + protected void setSLLMode(Parser parser) { + // parser.getInterpreter().setPredictionMode(PredictionMode.SLL); + parser.removeErrorListeners(); + parser.setErrorHandler(new BailErrorStrategy()); + } + + public void initializeContext(CommonParserContext context) { + filename = context.filename; + stream = context.stream; + } + + public void setStack(Stack> aStack) { + builderStack = aStack; + } + + // ////////////////// + + public void addObserver(AntlrParserDriverObserver observer) { + observers.add(observer); + } + + private void notifyObservers(Consumer function) { + for (AntlrParserDriverObserver observer : observers) { + function.accept(observer); + } + + } + + public void begin() { + notifyObserversOfBegin(); + } + + public void end() { + notifyObserversOfEnd(); + } + + private void notifyObserversOfBegin() { + notifyObservers(AntlrParserDriverObserver::begin); + } + + private void notifyObserversOfEnd() { + 
notifyObservers(AntlrParserDriverObserver::end); + } + + public void notifyObserversOfUnitStart(ParserRuleContext ctx) { + notifyObservers(new Consumer() { + @Override + public void accept(AntlrParserDriverObserver observer) { + observer.startOfUnit(ctx, filename); + } + }); + } + + public void notifyObserversOfUnitEnd(ParserRuleContext ctx) { + notifyObservers(new Consumer() { + @Override + public void accept(AntlrParserDriverObserver observer) { + observer.endOfUnit(ctx, filename); + } + }); + } + + public void notifyObserversOfItem(AstNode aItem) { + notifyObservers(new Consumer() { + @Override + public void accept(AntlrParserDriverObserver observer) { + observer.processItem(aItem, builderStack); + } + }); + } + + public CompoundStatement getResult() { + return (CompoundStatement) builderStack.peek().getItem(); + } + + public Parser getAntlrParser() { + return antlrParser; + } + + public void setAntlrParser(Parser aParser) { + antlrParser = aParser; + } + + public ParseTreeListener getListener() { + return listener; + } + + public void setListener(ParseTreeListener listener) { + this.listener = listener; + } + + public CommonParserContext getContext() { + return context; + } + + public void setContext(CommonParserContext context) { + this.context = context; + } + +} \ No newline at end of file diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/AntlrCFunctionParserDriver.java b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/AntlrCFunctionParserDriver.java index 97e92f7..d12064b 100644 --- a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/AntlrCFunctionParserDriver.java +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/AntlrCFunctionParserDriver.java @@ -2,8 +2,8 @@ import io.shiftleft.fuzzyc2cpg.FunctionLexer; import io.shiftleft.fuzzyc2cpg.FunctionParser; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; import io.shiftleft.fuzzyc2cpg.parser.TokenSubStream; -import 
io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; import org.antlr.v4.runtime.CharStream; import org.antlr.v4.runtime.Lexer; import org.antlr.v4.runtime.tree.ParseTree; diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/CFunctionParseTreeListener.java b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/CFunctionParseTreeListener.java index 3ff83b6..feaad5a 100644 --- a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/CFunctionParseTreeListener.java +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/functions/CFunctionParseTreeListener.java @@ -2,8 +2,8 @@ import io.shiftleft.fuzzyc2cpg.FunctionBaseListener; import io.shiftleft.fuzzyc2cpg.FunctionParser; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; import io.shiftleft.fuzzyc2cpg.parser.functions.builder.FunctionContentBuilder; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; /** * This is where hooks are registered for different types of parse tree nodes. diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/AntlrCModuleParserDriver.java b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/AntlrCModuleParserDriver.java new file mode 100644 index 0000000..f5d77f1 --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/AntlrCModuleParserDriver.java @@ -0,0 +1,42 @@ +package io.shiftleft.fuzzyc2cpg.parser.modules; + +import io.shiftleft.fuzzyc2cpg.ModuleLexer; +import io.shiftleft.fuzzyc2cpg.ModuleParser; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; +import io.shiftleft.fuzzyc2cpg.parser.TokenSubStream; +import org.antlr.v4.runtime.CharStream; +import org.antlr.v4.runtime.Lexer; +import org.antlr.v4.runtime.tree.ParseTree; + +public class AntlrCModuleParserDriver extends AntlrParserDriver { + + public AntlrCModuleParserDriver() { + super(); + setListener(new CModuleParserTreeListener(this)); + } + + @Override + public ParseTree parseTokenStreamImpl(TokenSubStream tokens) { + ModuleParser parser = new 
ModuleParser(tokens); + setAntlrParser(parser); + ParseTree tree = null; + + try { + setSLLMode(parser); + tree = parser.code(); + } catch (RuntimeException ex) { + if (isRecognitionException(ex)) { + tokens.reset(); + setLLStarMode(parser); + tree = parser.code(); + } + } + return tree; + } + + @Override + public Lexer createLexer(CharStream input) { + return new ModuleLexer(input); + } + +} diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/CModuleParserTreeListener.java b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/CModuleParserTreeListener.java new file mode 100644 index 0000000..d961758 --- /dev/null +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/CModuleParserTreeListener.java @@ -0,0 +1,227 @@ +package io.shiftleft.fuzzyc2cpg.parser.modules; + +import java.util.List; + +import org.antlr.v4.runtime.ParserRuleContext; + +import io.shiftleft.fuzzyc2cpg.ModuleBaseListener; +import io.shiftleft.fuzzyc2cpg.ModuleParser; +import io.shiftleft.fuzzyc2cpg.ModuleParser.Class_defContext; +import io.shiftleft.fuzzyc2cpg.ModuleParser.DeclByClassContext; +import io.shiftleft.fuzzyc2cpg.ModuleParser.Init_declarator_listContext; +import io.shiftleft.fuzzyc2cpg.ModuleParser.Type_nameContext; +import io.shiftleft.fuzzyc2cpg.ast.declarations.IdentifierDecl; +import io.shiftleft.fuzzyc2cpg.ast.logical.statements.CompoundStatement; +import io.shiftleft.fuzzyc2cpg.ast.statements.IdentifierDeclStatement; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; +import io.shiftleft.fuzzyc2cpg.parser.CompoundItemAssembler; +import io.shiftleft.fuzzyc2cpg.parser.ModuleFunctionParserInterface; +import io.shiftleft.fuzzyc2cpg.parser.modules.builder.FunctionDefBuilder; +import io.shiftleft.fuzzyc2cpg.parser.shared.builders.ClassDefBuilder; +import io.shiftleft.fuzzyc2cpg.parser.shared.builders.IdentifierDeclBuilder; +import io.shiftleft.fuzzyc2cpg.parser.shared.builders.TemplateAstBuilder; + +// Converts Parse Trees to ASTs for Modules + +public class 
CModuleParserTreeListener extends ModuleBaseListener { + + private final AntlrParserDriver p; + + public CModuleParserTreeListener(AntlrParserDriver p) { + this.p = p; + } + + @Override + public void enterCode(ModuleParser.CodeContext ctx) { + p.notifyObserversOfUnitStart(ctx); + } + + @Override + public void exitCode(ModuleParser.CodeContext ctx) { + p.notifyObserversOfUnitEnd(ctx); + } + + @Override + public void enterFunction_decl(ModuleParser.Function_declContext ctx) { + FunctionDefBuilder builder = new FunctionDefBuilder(); + builder.createNew(ctx); + builder.setIsOnlyDeclaration(true); + builder.setContent(new CompoundStatement()); + p.builderStack.push(builder); + } + + @Override + public void exitFunction_decl(ModuleParser.Function_declContext ctx) { + FunctionDefBuilder builder = (FunctionDefBuilder) p.builderStack.pop(); + p.notifyObserversOfItem(builder.getItem()); + } + + // ///////////////////////////////////////////////////////////// + // This is where the ModuleParser invokes the FunctionParser + // ///////////////////////////////////////////////////////////// + // This function is invoked when a Function_Def parse tree node + // is entered. This is where we hand over the function contents to + // the function parser and connect the AST node created for the + // function definition to the AST created by the function parser. 
+ // //////////////////////////////////////////////////////////////// + + @Override + public void enterFunction_def(ModuleParser.Function_defContext ctx) { + FunctionDefBuilder builder = new FunctionDefBuilder(); + builder.createNew(ctx); + p.builderStack.push(builder); + + CompoundStatement functionContent = ModuleFunctionParserInterface + .parseFunctionContents(ctx); + builder.setContent(functionContent); + } + + @Override + public void exitFunction_def(ModuleParser.Function_defContext ctx) { + FunctionDefBuilder builder = (FunctionDefBuilder) p.builderStack.pop(); + p.notifyObserversOfItem(builder.getItem()); + } + + @Override + public void enterReturn_type(ModuleParser.Return_typeContext ctx) { + FunctionDefBuilder builder = (FunctionDefBuilder) p.builderStack.peek(); + builder.setReturnType(ctx); + } + + @Override + public void enterFunction_name(ModuleParser.Function_nameContext ctx) { + FunctionDefBuilder builder = (FunctionDefBuilder) p.builderStack.peek(); + builder.setName(ctx); + } + + @Override + public void enterFunction_param_list( + ModuleParser.Function_param_listContext ctx) { + FunctionDefBuilder builder = (FunctionDefBuilder) p.builderStack.peek(); + builder.setParameterList(ctx); + } + + @Override + public void enterParameter_decl(ModuleParser.Parameter_declContext ctx) { + FunctionDefBuilder builder = (FunctionDefBuilder) p.builderStack.peek(); + builder.addParameter(ctx); + } + + @Override + public void enterTemplate_decl(ModuleParser.Template_declContext ctx) { + TemplateAstBuilder builder = (TemplateAstBuilder) p.builderStack.peek(); + builder.setTemplateList(ctx); + } + + @Override + public void enterTemplate_name(ModuleParser.Template_nameContext ctx) { + TemplateAstBuilder builder = (TemplateAstBuilder) p.builderStack.peek(); + builder.addTemplateParameter(ctx); + } + + // DeclByType + + @Override + public void enterDeclByType(ModuleParser.DeclByTypeContext ctx) { + Init_declarator_listContext decl_list = ctx.init_declarator_list(); + 
Type_nameContext typeName = ctx.type_name(); + IdentifierDeclBuilder builder = new IdentifierDeclBuilder(); + p.builderStack.push(builder); + emitDeclarations(builder, decl_list, typeName, ctx); + } + + @Override + public void exitDeclByType(ModuleParser.DeclByTypeContext ctx) { + p.builderStack.pop(); + } + + private void emitDeclarations(IdentifierDeclBuilder identifierDeclBuilder, + ParserRuleContext decl_list, + ParserRuleContext typeName, + ParserRuleContext ctx) { + + List declarations = identifierDeclBuilder.getDeclarations(decl_list, typeName); + IdentifierDeclStatement stmt = new IdentifierDeclStatement(); + + boolean isTypedef = ctx.getParent().start.getText().equals("typedef"); + + for (IdentifierDecl decl : declarations) { + decl.setIsTypedef(isTypedef); + stmt.addChild(decl); + } + + p.notifyObserversOfItem(stmt); + } + + // DeclByClass + + @Override + public void enterDeclByClass(ModuleParser.DeclByClassContext ctx) { + ClassDefBuilder builder = new ClassDefBuilder(); + builder.createNew(ctx); + p.builderStack.push(builder); + } + + @Override + public void exitDeclByClass(ModuleParser.DeclByClassContext ctx) { + ClassDefBuilder builder = (ClassDefBuilder) p.builderStack.pop(); + + CompoundStatement content = parseClassContent(ctx); + builder.setContent(content); + + if(ctx.class_def().base_classes() != null && ctx.class_def().base_classes().base_class() != null) { + for (ModuleParser.Base_classContext baseClassCtx : ctx.class_def().base_classes().base_class()) { + builder.addBaseClass(baseClassCtx); + } + } + + p.notifyObserversOfItem(builder.getItem()); + emitDeclarationsForClass(ctx); + } + + @Override + public void enterClass_name(ModuleParser.Class_nameContext ctx) { + ClassDefBuilder builder = (ClassDefBuilder) p.builderStack.peek(); + builder.setName(ctx); + } + + private void emitDeclarationsForClass(DeclByClassContext ctx) { + + Init_declarator_listContext decl_list = ctx.init_declarator_list(); + if (decl_list == null) { + return; + } + + 
ParserRuleContext typeName = ctx.class_def().class_name(); + emitDeclarations(new IdentifierDeclBuilder(), decl_list, typeName, ctx); + } + + private CompoundStatement parseClassContent( + ModuleParser.DeclByClassContext ctx) { + AntlrCModuleParserDriver shallowParser = createNewShallowParser(); + CompoundItemAssembler generator = new CompoundItemAssembler(); + shallowParser.addObserver(generator); + + restrictStreamToClassContent(ctx); + shallowParser.parseAndWalkTokenStream(p.stream); + p.stream.resetRestriction(); + + return generator.getCompoundItem(); + } + + private void restrictStreamToClassContent( + ModuleParser.DeclByClassContext ctx) { + Class_defContext class_def = ctx.class_def(); + int startIndex = class_def.OPENING_CURLY().getSymbol().getTokenIndex(); + int stopIndex = class_def.stop.getTokenIndex(); + + p.stream.restrict(startIndex + 1, stopIndex); + } + + private AntlrCModuleParserDriver createNewShallowParser() { + AntlrCModuleParserDriver shallowParser = new AntlrCModuleParserDriver(); + shallowParser.setStack(p.builderStack); + return shallowParser; + } + +} diff --git a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/builder/FunctionDefBuilder.java b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/builder/FunctionDefBuilder.java index e56d48d..ad2a126 100644 --- a/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/builder/FunctionDefBuilder.java +++ b/src/main/java/io/shiftleft/fuzzyc2cpg/parser/modules/builder/FunctionDefBuilder.java @@ -1,8 +1,12 @@ package io.shiftleft.fuzzyc2cpg.parser.modules.builder; +import java.util.Stack; + import org.antlr.v4.runtime.ParserRuleContext; import io.shiftleft.fuzzyc2cpg.ModuleParser.*; +import io.shiftleft.fuzzyc2cpg.ast.AstNode; +import io.shiftleft.fuzzyc2cpg.ast.AstNodeBuilder; import io.shiftleft.fuzzyc2cpg.ast.expressions.Identifier; import io.shiftleft.fuzzyc2cpg.ast.functionDef.ReturnType; import io.shiftleft.fuzzyc2cpg.ast.langc.functiondef.FunctionDef; diff --git 
a/src/main/scala/io/shiftleft/fuzzyc2cpg/AstVisitor.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/AstVisitor.scala new file mode 100644 index 0000000..4c7e32d --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/AstVisitor.scala @@ -0,0 +1,99 @@ +package io.shiftleft.fuzzyc2cpg + +import java.util + +import io.shiftleft.fuzzyc2cpg.adapter.ProtoCpgAdapter +import io.shiftleft.fuzzyc2cpg.ast.{AstNode, AstNodeBuilder} +import io.shiftleft.fuzzyc2cpg.ast.declarations.ClassDefStatement +import io.shiftleft.fuzzyc2cpg.ast.langc.functiondef.FunctionDef +import io.shiftleft.fuzzyc2cpg.ast.statements.IdentifierDeclStatement +import io.shiftleft.fuzzyc2cpg.ast.walking.ASTNodeVisitor +import io.shiftleft.fuzzyc2cpg.astnew.AstToCpgConverter +import io.shiftleft.fuzzyc2cpg.cfg.AstToCfgConverter +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModuleFactory +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriverObserver +import io.shiftleft.passes.KeyPool +import io.shiftleft.proto.cpg.Cpg.CpgStruct +import io.shiftleft.proto.cpg.Cpg.CpgStruct.Node +import org.antlr.v4.runtime.ParserRuleContext + +class AstVisitor(outputModuleFactory: CpgOutputModuleFactory, + astParentNode: Node, + keyPool: KeyPool, + cache: FuzzyC2CpgCache, + global: Global) + extends ASTNodeVisitor + with AntlrParserDriverObserver { + private var fileNameOption = Option.empty[String] + private val structureCpg = CpgStruct.newBuilder() + + /** + * Callback triggered for each function definition + * */ + override def visit(functionDef: FunctionDef): Unit = { + val outputModule = outputModuleFactory.create() + val outputIdentifier = s"${fileNameOption.get}${functionDef.getName}" + + s"${functionDef.getLocation.startLine}${functionDef.getLocation.endLine}" + outputModule.setOutputIdentifier(outputIdentifier) + + val bodyCpg = CpgStruct.newBuilder() + val cpgAdapter = new ProtoCpgAdapter(bodyCpg, keyPool) + val astToCpgConverter = + new AstToCpgConverter(astParentNode, cpgAdapter, global) + 
astToCpgConverter.convert(functionDef) + + val astToCfgConverter = + new AstToCfgConverter(astToCpgConverter.getMethodNode.get, astToCpgConverter.getMethodReturnNode.get, cpgAdapter) + astToCfgConverter.convert(functionDef) + + if (functionDef.isOnlyDeclaration) { + // Do not persist the declaration. It may be that we encounter a + // corresponding definition, in which case the declaration will be + // removed again and is never persisted. Persisting of declarations + // happens after concurrent processing of compilation units. + cache.add(functionDef.getFunctionSignature(false), outputIdentifier, bodyCpg) + } else { + cache.remove(functionDef.getFunctionSignature(false)) + outputModule.persistCpg(bodyCpg) + } + } + + /** + * Callback triggered for every class/struct + * */ + override def visit(classDefStatement: ClassDefStatement): Unit = { + val cpgAdapter = new ProtoCpgAdapter(structureCpg, keyPool) + val astToCpgConverter = + new AstToCpgConverter(astParentNode, cpgAdapter, global) + astToCpgConverter.convert(classDefStatement) + } + + /** + * Callback triggered for every global identifier declaration + * */ + override def visit(identifierDeclStmt: IdentifierDeclStatement): Unit = { + val cpgAdapter = new ProtoCpgAdapter(structureCpg, keyPool) + val astToCpgConverter = + new AstToCpgConverter(astParentNode, cpgAdapter, global) + astToCpgConverter.convert(identifierDeclStmt) + } + + override def begin(): Unit = {} + + override def end(): Unit = {} + + override def startOfUnit(ctx: ParserRuleContext, filename: String): Unit = { + fileNameOption = Some(filename) + } + + override def endOfUnit(ctx: ParserRuleContext, filename: String): Unit = { + val identifier = s"$filename types" + val outputModule = outputModuleFactory.create() + outputModule.setOutputIdentifier(identifier) + outputModule.persistCpg(structureCpg) + } + + override def processItem[T <: AstNode](node: T, builderStack: util.Stack[AstNodeBuilder[_ <: AstNode]]): Unit = { + node.accept(this) + } +} 
diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/FuzzyC2Cpg.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/FuzzyC2Cpg.scala index 0b24f73..a728fed 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/FuzzyC2Cpg.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/FuzzyC2Cpg.scala @@ -1,30 +1,43 @@ package io.shiftleft.fuzzyc2cpg import org.slf4j.LoggerFactory -import java.nio.file.Files -import java.util.concurrent.ConcurrentHashMap -import better.files.File -import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.fuzzyc2cpg.passes.{AstCreationPass, CMetaDataPass, CfgCreationPass, StubRemovalPass, TypeNodePass} -import io.shiftleft.passes.IntervalKeyPool -import overflowdb.{OdbConfig, OdbGraph} +import io.shiftleft.codepropertygraph.generated.Languages +import io.shiftleft.fuzzyc2cpg.Utils.{getGlobalNamespaceBlockFullName, newEdge, newNode, _} +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModuleFactory +import io.shiftleft.fuzzyc2cpg.output.protobuf.OutputModuleFactory +import io.shiftleft.fuzzyc2cpg.parser.modules.AntlrCModuleParserDriver +import io.shiftleft.proto.cpg.Cpg.CpgStruct.Edge.EdgeType +import io.shiftleft.proto.cpg.Cpg.CpgStruct.Node +import io.shiftleft.proto.cpg.Cpg.CpgStruct.Node.NodeType +import io.shiftleft.proto.cpg.Cpg.{CpgStruct, NodePropertyName} +import java.nio.file.{Files, Path} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} + +import io.shiftleft.passes.KeyPool + import scala.collection.mutable.ListBuffer +import scala.collection.parallel.CollectionConverters._ import scala.util.control.NonFatal import scala.jdk.CollectionConverters._ case class Global(usedTypes: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap[String, Boolean]()) -class FuzzyC2Cpg() { +class FuzzyC2Cpg(outputModuleFactory: CpgOutputModuleFactory) { import FuzzyC2Cpg.logger + def this(outputPath: String) = { + this(new OutputModuleFactory(outputPath, true).asInstanceOf[CpgOutputModuleFactory]) + } + + private val cache = new 
FuzzyC2CpgCache + def runWithPreprocessorAndOutput(sourcePaths: Set[String], sourceFileExtensions: Set[String], includeFiles: Set[String], includePaths: Set[String], defines: Set[String], undefines: Set[String], - preprocessorExecutable: String, - optionalOutputPath: Option[String] = None): Unit = { + preprocessorExecutable: String): Unit = { // Create temp dir to store preprocessed source. val preprocessedPath = Files.createTempDirectory("fuzzyc2cpg_preprocessed_") logger.info(s"Writing preprocessed files to [$preprocessedPath]") @@ -55,55 +68,161 @@ class FuzzyC2Cpg() { if (exitCode == 0) { logger.info(s"Preprocessing complete, files written to [$preprocessedPath], starting CPG generation...") - val cpg = runAndOutput(Set(preprocessedPath.toString), sourceFileExtensions, optionalOutputPath) - cpg.close() + runAndOutput(Set(preprocessedPath.toString), sourceFileExtensions) } else { logger.error( s"Error occurred whilst running preprocessor. Log written to [$preprocessorLogFile]. Exit code [$exitCode].") } } - def runAndOutput(sourcePaths: Set[String], - sourceFileExtensions: Set[String], - optionalOutputPath: Option[String] = None): Cpg = { - val metaDataKeyPool = new IntervalKeyPool(1, 100) - val typesKeyPool = new IntervalKeyPool(100, 1000100) - val functionKeyPools = KeyPools.obtain(2, 1000101) - - val cpg = initCpg(optionalOutputPath) + def runAndOutput(sourcePaths: Set[String], sourceFileExtensions: Set[String]): Unit = { val sourceFileNames = SourceFiles.determine(sourcePaths, sourceFileExtensions) + val keyPools = KeyPools.obtain(sourceFileNames.size.toLong + 2) - new CMetaDataPass(cpg, Some(metaDataKeyPool)).createAndApply() - val astCreator = new AstCreationPass(sourceFileNames, cpg, functionKeyPools.head) - astCreator.createAndApply() - new CfgCreationPass(cpg, functionKeyPools.last).createAndApply() - new StubRemovalPass(cpg).createAndApply() - new TypeNodePass(astCreator.global.usedTypes.keys().asScala.toList, cpg, Some(typesKeyPool)).createAndApply() 
- cpg + val fileAndNamespaceKeyPool = keyPools.head + val typesKeyPool = keyPools(1) + val compilationUnitKeyPools = keyPools.slice(2, keyPools.size) + + addFilesAndNamespaces(fileAndNamespaceKeyPool) + val global = addCompilationUnits(sourceFileNames, compilationUnitKeyPools) + addFunctionDeclarations(cache) + addTypeNodes(global.usedTypes, typesKeyPool) + outputModuleFactory.persist() } - /** - * Create an empty CPG, backed by the file at `optionalOutputPath` or - * in-memory if `optionalOutputPath` is empty. - * */ - private def initCpg(optionalOutputPath: Option[String]): Cpg = { - val odbConfig = optionalOutputPath - .map { outputPath => - val outFile = File(outputPath) - if (outputPath != "" && outFile.exists) { - logger.info("Output file exists, removing: " + outputPath) - outFile.delete() - } - OdbConfig.withDefaults.withStorageLocation(outputPath) + private def addFilesAndNamespaces(keyPool: KeyPool): Unit = { + val fileAndNamespaceCpg = CpgStruct.newBuilder() + createStructuralCpg(keyPool, fileAndNamespaceCpg) + val outputModule = outputModuleFactory.create() + outputModule.setOutputIdentifier("__structural__") + outputModule.persistCpg(fileAndNamespaceCpg) + } + + // TODO improve fuzzyc2cpg namespace support. Currently, everything + // is in the same global namespace so the code below is correct. 
+ private def addCompilationUnits(sourceFileNames: List[String], keyPools: List[KeyPool]): Global = { + val global = Global() + sourceFileNames.zipWithIndex + .map { case (filename, i) => (filename, keyPools(i)) } + .par + .foreach { case (filename, keyPool) => createCpgForCompilationUnit(filename, keyPool, global) } + global + } + + private def addFunctionDeclarations(cache: FuzzyC2CpgCache): Unit = { + cache.sortedSignatures.par.foreach { signature => + cache.getDeclarations(signature).foreach { + case (outputIdentifier, bodyCpg) => + val outputModule = outputModuleFactory.create() + outputModule.setOutputIdentifier(outputIdentifier) + outputModule.persistCpg(bodyCpg) } - .getOrElse { - OdbConfig.withDefaults() + } + } + + private def addTypeNodes(usedTypes: ConcurrentHashMap[String, Boolean], keyPool: KeyPool): Unit = { + val cpg = CpgStruct.newBuilder() + val outputModule = outputModuleFactory.create() + outputModule.setOutputIdentifier("__types__") + createTypeNodes(usedTypes, keyPool, cpg) + outputModule.persistCpg(cpg) + } + + private def fileAndNamespaceGraph(filename: String, keyPool: KeyPool): (Node, Node) = { + + def createFileNode(pathToFile: Path, keyPool: KeyPool): Node = { + newNode(NodeType.FILE) + .setKey(keyPool.next) + .addStringProperty(NodePropertyName.NAME, pathToFile.toAbsolutePath.normalize.toString) + .build() + } + + val cpg = CpgStruct.newBuilder() + val outputModule = outputModuleFactory.create() + outputModule.setOutputIdentifier(filename + " fileAndNamespace") + + val pathToFile = new java.io.File(filename).toPath + val fileNode = createFileNode(pathToFile, keyPool) + val namespaceBlock = createNamespaceBlockNode(Some(pathToFile), keyPool) + cpg.addNode(fileNode) + cpg.addNode(namespaceBlock) + cpg.addEdge(newEdge(EdgeType.AST, namespaceBlock, fileNode)) + outputModule.persistCpg(cpg) + (fileNode, namespaceBlock) + } + + private def createNamespaceBlockNode(filePath: Option[Path], keyPool: KeyPool): Node = { + 
newNode(NodeType.NAMESPACE_BLOCK) + .setKey(keyPool.next) + .addStringProperty(NodePropertyName.NAME, Defines.globalNamespaceName) + .addStringProperty(NodePropertyName.FULL_NAME, getGlobalNamespaceBlockFullName(filePath.map(_.toString))) + .build + } + + private def createTypeNodes(usedTypes: ConcurrentHashMap[String, Boolean], + keyPool: KeyPool, + cpg: CpgStruct.Builder): Unit = { + usedTypes + .keys() + .asScala + .toList + .sorted + .foreach { typeName => + val node = newNode(NodeType.TYPE) + .setKey(keyPool.next) + .addStringProperty(NodePropertyName.NAME, typeName) + .addStringProperty(NodePropertyName.FULL_NAME, typeName) + .addStringProperty(NodePropertyName.TYPE_DECL_FULL_NAME, typeName) + .build + cpg.addNode(node) } + } + + private def createStructuralCpg(keyPool: KeyPool, cpg: CpgStruct.Builder): Unit = { - val graph = OdbGraph.open(odbConfig, - io.shiftleft.codepropertygraph.generated.nodes.Factories.allAsJava, - io.shiftleft.codepropertygraph.generated.edges.Factories.allAsJava) - new Cpg(graph) + def addMetaDataNode(cpg: CpgStruct.Builder): Unit = { + val metaNode = newNode(NodeType.META_DATA) + .setKey(keyPool.next) + .addStringProperty(NodePropertyName.LANGUAGE, Languages.C) + .build + cpg.addNode(metaNode) + } + + def addAnyTypeAndNamespaceBlock(cpg: CpgStruct.Builder): Unit = { + val globalNamespaceBlockNotInFileNode = createNamespaceBlockNode(None, keyPool) + cpg.addNode(globalNamespaceBlockNotInFileNode) + } + + addMetaDataNode(cpg) + addAnyTypeAndNamespaceBlock(cpg) + } + + private def createCpgForCompilationUnit(filename: String, keyPool: KeyPool, global: Global): Unit = { + val (fileNode, namespaceBlock) = fileAndNamespaceGraph(filename, keyPool) + + // We call the module parser here and register the `astVisitor` to + // receive callbacks as we walk the tree. 
The method body parser + // will the invoked by `astVisitor` as we walk the tree + + val driver = new AntlrCModuleParserDriver() + val astVisitor = + new AstVisitor(outputModuleFactory, namespaceBlock, keyPool, cache, global) + driver.addObserver(astVisitor) + driver.setKeyPool(keyPool) + driver.setOutputModuleFactory(outputModuleFactory) + driver.setFileNode(fileNode) + + try { + driver.parseAndWalkFile(filename) + } catch { + case ex: RuntimeException => { + logger.warn("Cannot parse module: " + filename + ", skipping") + logger.warn("Complete exception: ", ex) + } + case _: StackOverflowError => { + logger.warn("Cannot parse module: " + filename + ", skipping, StackOverflow") + } + } } } @@ -115,22 +234,27 @@ object FuzzyC2Cpg { def main(args: Array[String]): Unit = { parseConfig(args).foreach { config => try { - val fuzzyc = new FuzzyC2Cpg() + + val factory = if (!config.overflowDb) { + new OutputModuleFactory(config.outputPath, true) + .asInstanceOf[CpgOutputModuleFactory] + } else { + val queue = new LinkedBlockingQueue[CpgStruct.Builder]() + new io.shiftleft.fuzzyc2cpg.output.overflowdb.OutputModuleFactory(config.outputPath, queue) + } + + val fuzzyc = new FuzzyC2Cpg(factory) if (config.usePreprocessor) { - fuzzyc.runWithPreprocessorAndOutput( - config.inputPaths, - config.sourceFileExtensions, - config.includeFiles, - config.includePaths, - config.defines, - config.undefines, - config.preprocessorExecutable, - Some(config.outputPath) - ) + fuzzyc.runWithPreprocessorAndOutput(config.inputPaths, + config.sourceFileExtensions, + config.includeFiles, + config.includePaths, + config.defines, + config.undefines, + config.preprocessorExecutable) } else { - val cpg = fuzzyc.runAndOutput(config.inputPaths, config.sourceFileExtensions, Some(config.outputPath)) - cpg.close() + fuzzyc.runAndOutput(config.inputPaths, config.sourceFileExtensions) } } catch { diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/FuzzyC2CpgCache.scala 
b/src/main/scala/io/shiftleft/fuzzyc2cpg/FuzzyC2CpgCache.scala new file mode 100644 index 0000000..10fcee0 --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/FuzzyC2CpgCache.scala @@ -0,0 +1,55 @@ +package io.shiftleft.fuzzyc2cpg + +import io.shiftleft.proto.cpg.Cpg.CpgStruct + +import scala.collection.mutable + +class FuzzyC2CpgCache { + private val functionDeclarations = new mutable.HashMap[String, mutable.ListBuffer[(String, CpgStruct.Builder)]]() + + /** + * Unless `remove` has been called for `signature`, add (outputIdentifier, cpg) + * pair to the list declarations stored for `signature`. + * */ + def add(signature: String, outputIdentifier: String, cpg: CpgStruct.Builder): Unit = { + functionDeclarations.synchronized { + if (functionDeclarations.contains(signature)) { + val declList = functionDeclarations(signature) + // null is the placeholder that indicates that we've removed + // a function with this signature before, and hence, we do + // not need to add it again + if (declList == null) return + if (declList.nonEmpty) { + declList.append((outputIdentifier, cpg)) + } + } else { + functionDeclarations.put(signature, mutable.ListBuffer((outputIdentifier, cpg))) + } + } + } + + /** + * Register placeholder for `signature` to indicate that + * a function definition exists for this declaration, and + * therefore, no declaration should be written for functions + * with this signature. 
+ * */ + def remove(signature: String): Unit = { + functionDeclarations.synchronized { + functionDeclarations.put(signature, null) + } + } + + def sortedSignatures: List[String] = { + functionDeclarations.synchronized { + functionDeclarations.filter(_._2 != null).keySet.toList.sorted + } + } + + def getDeclarations(signature: String): List[(String, CpgStruct.Builder)] = { + functionDeclarations.synchronized { + functionDeclarations(signature).toList.filter(_._2 != null) + } + } + +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/KeyPools.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/KeyPools.scala index e252fdd..cc53e40 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/KeyPools.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/KeyPools.scala @@ -1,6 +1,6 @@ package io.shiftleft.fuzzyc2cpg -import io.shiftleft.passes.IntervalKeyPool +import io.shiftleft.passes.{IntervalKeyPool, KeyPool} object KeyPools { @@ -8,11 +8,11 @@ object KeyPools { * Divide the keyspace into n intervals and return * a list of corresponding key pools. 
* */ - def obtain(n: Long, minValue: Long = 0, maxValue: Long = Long.MaxValue): List[IntervalKeyPool] = { + def obtain(n: Long, maxValue: Long = Long.MaxValue): List[KeyPool] = { val nIntervals = Math.max(n, 1) - val intervalLen: Long = (maxValue - minValue) / nIntervals + val intervalLen: Long = maxValue / nIntervals List.range(0, nIntervals).map { i => - val first = i * intervalLen + minValue + val first = i * intervalLen val last = first + intervalLen - 1 new IntervalKeyPool(first, last) } diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/Utils.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/Utils.scala new file mode 100644 index 0000000..5918591 --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/Utils.scala @@ -0,0 +1,89 @@ +package io.shiftleft.fuzzyc2cpg + +import io.shiftleft.fuzzyc2cpg.ast.AstNode +import io.shiftleft.proto.cpg.Cpg +import io.shiftleft.proto.cpg.Cpg.CpgStruct.{Edge, Node} +import io.shiftleft.proto.cpg.Cpg.CpgStruct.Node.{NodeType, Property} +import io.shiftleft.proto.cpg.Cpg.{CpgStruct, PropertyValue, StringList} + +object Utils { + + def newStringProperty(name: Cpg.NodePropertyName, value: String): Property.Builder = { + Property.newBuilder + .setName(name) + .setValue(PropertyValue.newBuilder.setStringValue(value).build) + } + + def newIntProperty(name: Cpg.NodePropertyName, value: Int): Property.Builder = { + Property.newBuilder + .setName(name) + .setValue(PropertyValue.newBuilder.setIntValue(value).build) + } + + def newBooleanProperty(name: Cpg.NodePropertyName, value: Boolean): Property.Builder = { + Property.newBuilder + .setName(name) + .setValue(PropertyValue.newBuilder.setBoolValue(value).build) + } + + def newStringListProperty(name: Cpg.NodePropertyName, value: List[String]): Property.Builder = { + val slb = StringList.newBuilder() + value.map { slb.addValues(_) } + slb.build() + Property.newBuilder + .setName(name) + .setValue(PropertyValue.newBuilder.setStringList(slb).build) + } + + def newNode(nodeType: NodeType): 
Node.Builder = { + Node + .newBuilder() + .setType(nodeType) + } + + def newEdge(edgeType: Edge.EdgeType, dstNode: Node, srcNode: Node): Edge.Builder = { + Edge + .newBuilder() + .setType(edgeType) + .setDst(dstNode.getKey) + .setSrc(srcNode.getKey) + } + + def children(node: AstNode) = + (0 to node.getChildCount) + .map(node.getChild) + .filterNot(_ == null) + .toList + + def getGlobalNamespaceBlockFullName(fileNameOption: Option[String]): String = { + fileNameOption match { + case Some(fileName) => + s"$fileName:${Defines.globalNamespaceName}" + case None => + Defines.globalNamespaceName + } + } + + implicit class NodeBuilderWrapper(nodeBuilder: Node.Builder) { + def addStringProperty(name: Cpg.NodePropertyName, value: String): Node.Builder = { + nodeBuilder.addProperty(newStringProperty(name, value)) + } + def addIntProperty(name: Cpg.NodePropertyName, value: Int): Node.Builder = { + nodeBuilder.addProperty(newIntProperty(name, value)) + } + def addBooleanProperty(name: Cpg.NodePropertyName, value: Boolean): Node.Builder = { + nodeBuilder.addProperty(newBooleanProperty(name, value)) + } + def addStringListProperty(name: Cpg.NodePropertyName, value: List[String]): Node.Builder = { + nodeBuilder.addProperty(newStringListProperty(name, value)) + } + + } + + implicit class CpgStructBuilderWrapper(cpgStructBuilder: CpgStruct.Builder) { + def addEdge(edgeType: Edge.EdgeType, dstNode: Node, srcNode: Node): CpgStruct.Builder = { + cpgStructBuilder.addEdge(newEdge(edgeType, dstNode, srcNode)) + } + } + +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/adapter/CpgAdapter.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/adapter/CpgAdapter.scala new file mode 100644 index 0000000..12a204f --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/adapter/CpgAdapter.scala @@ -0,0 +1,69 @@ +package io.shiftleft.fuzzyc2cpg.adapter + +import io.shiftleft.fuzzyc2cpg.adapter.EdgeKind.EdgeKind +import io.shiftleft.fuzzyc2cpg.adapter.EdgeProperty.EdgeProperty +import 
io.shiftleft.fuzzyc2cpg.adapter.NodeKind.NodeKind +import io.shiftleft.fuzzyc2cpg.adapter.NodeProperty.NodeProperty +import io.shiftleft.fuzzyc2cpg.ast.AstNode + +object NodeProperty extends Enumeration { + type NodeProperty = Value + val ORDER, ARGUMENT_INDEX, NAME, FULL_NAME, CODE, EVALUATION_STRATEGY, TYPE_FULL_NAME, TYPE_DECL_FULL_NAME, SIGNATURE, + DISPATCH_TYPE, METHOD_FULL_NAME, METHOD_INST_FULL_NAME, IS_EXTERNAL, PARSER_TYPE_NAME, AST_PARENT_TYPE, + AST_PARENT_FULL_NAME, LINE_NUMBER, COLUMN_NUMBER, LINE_NUMBER_END, COLUMN_NUMBER_END, ALIAS_TYPE_FULL_NAME, + INHERITS_FROM_TYPE_FULL_NAME, CANONICAL_NAME = Value +} + +object NodeKind extends Enumeration { + type NodeKind = Value + val METHOD, METHOD_RETURN, METHOD_PARAMETER_IN, METHOD_INST, CALL, LITERAL, IDENTIFIER, JUMP_TARGET, BLOCK, RETURN, + LOCAL, TYPE, TYPE_DECL, MEMBER, NAMESPACE_BLOCK, CONTROL_STRUCTURE, UNKNOWN, FIELD_IDENTIFIER = Value +} + +object EdgeProperty extends Enumeration { + type EdgeProperty = Value + val CFG_EDGE_TYPE = Value +} + +object EdgeKind extends Enumeration { + type EdgeKind = Value + val AST, CFG, REF, CONDITION, ARGUMENT = Value +} + +trait CfgEdgeType +object TrueEdge extends CfgEdgeType { + override def toString: String = "TrueEdge" +} +object FalseEdge extends CfgEdgeType { + override def toString: String = "FalseEdge" +} +object AlwaysEdge extends CfgEdgeType { + override def toString: String = "AlwaysEdge" +} +object CaseEdge extends CfgEdgeType { + override def toString: String = "CaseEdge" +} + +trait CpgAdapter[NodeBuilderType, NodeType, EdgeBuilderType, EdgeType] { + def createNodeBuilder(kind: NodeKind): NodeBuilderType + + def createNode(nodeBuilder: NodeBuilderType): NodeType + + def createNode(nodeBuilder: NodeBuilderType, origAstNode: AstNode): NodeType + + def addNodeProperty(nodeBuilder: NodeBuilderType, property: NodeProperty, value: String): Unit + + def addNodeProperty(nodeBuilder: NodeBuilderType, property: NodeProperty, value: Int): Unit + + def 
addNodeProperty(nodeBuilder: NodeBuilderType, property: NodeProperty, value: Boolean): Unit + + def addNodeProperty(nodeBuilder: NodeBuilderType, property: NodeProperty, value: List[String]): Unit + + def createEdgeBuilder(dst: NodeType, src: NodeType, edgeKind: EdgeKind): EdgeBuilderType + + def createEdge(edgeBuilder: EdgeBuilderType): EdgeType + + def addEdgeProperty(edgeBuilder: EdgeBuilderType, property: EdgeProperty, value: String): Unit + + def mapNode(astNode: AstNode): NodeType +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/adapter/ProtoCpgAdapter.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/adapter/ProtoCpgAdapter.scala new file mode 100644 index 0000000..3783692 --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/adapter/ProtoCpgAdapter.scala @@ -0,0 +1,144 @@ +package io.shiftleft.fuzzyc2cpg.adapter + +import io.shiftleft.fuzzyc2cpg.Utils._ +import io.shiftleft.fuzzyc2cpg.adapter.EdgeKind.EdgeKind +import io.shiftleft.fuzzyc2cpg.adapter.EdgeProperty.EdgeProperty +import io.shiftleft.fuzzyc2cpg.adapter.NodeKind.NodeKind +import io.shiftleft.fuzzyc2cpg.adapter.NodeProperty.NodeProperty +import io.shiftleft.fuzzyc2cpg.ast.AstNode +import io.shiftleft.passes.KeyPool +import io.shiftleft.proto.cpg.Cpg.CpgStruct.{Edge, Node} +import io.shiftleft.proto.cpg.Cpg.{CpgStruct, NodePropertyName} + +class ProtoCpgAdapter(targetCpg: CpgStruct.Builder, keyPool: KeyPool) + extends CpgAdapter[Node.Builder, Node, Edge.Builder, Edge] { + private var astToProtoMapping = Map.empty[AstNode, Node] + + override def createNodeBuilder(kind: NodeKind): Node.Builder = { + Node.newBuilder().setType(translateNodeKind(kind)).setKey(keyPool.next) + } + + override def createNode(nodeBuilder: Node.Builder): Node = { + val node = nodeBuilder.build + + targetCpg.addNode(nodeBuilder) + + node + } + + override def createNode(nodeBuilder: Node.Builder, origAstNode: AstNode): Node = { + val node = createNode(nodeBuilder) + + astToProtoMapping += origAstNode -> node + + node + 
} + + override def addNodeProperty(nodeBuilder: Node.Builder, property: NodeProperty, value: String): Unit = { + nodeBuilder.addStringProperty(translateNodeProperty(property), value) + } + + override def addNodeProperty(nodeBuilder: Node.Builder, property: NodeProperty, value: Int): Unit = { + nodeBuilder.addIntProperty(translateNodeProperty(property), value) + } + + override def addNodeProperty(nodeBuilder: Node.Builder, property: NodeProperty, value: Boolean): Unit = { + nodeBuilder.addBooleanProperty(translateNodeProperty(property), value) + } + + override def addNodeProperty(nodeBuilder: Node.Builder, property: NodeProperty, value: List[String]): Unit = { + nodeBuilder.addStringListProperty(translateNodeProperty(property), value) + } + + override def createEdgeBuilder(dst: Node, src: Node, edgeKind: EdgeKind): Edge.Builder = { + Edge + .newBuilder() + .setType(translateEdgeKind(edgeKind)) + .setDst(dst.getKey) + .setSrc(src.getKey) + } + + override def createEdge(edgeBuilder: Edge.Builder): Edge = { + val edge = edgeBuilder.build + + targetCpg.addEdge(edge) + + edge + } + + override def addEdgeProperty(edgeBuilder: Edge.Builder, property: EdgeProperty, value: String): Unit = { + if (property != EdgeProperty.CFG_EDGE_TYPE) { + throw new RuntimeException("Not yet implemented.") + } + } + + override def mapNode(astNode: AstNode): Node = { + astToProtoMapping(astNode) + } + + private def translateNodeProperty(nodeProperty: NodeProperty): NodePropertyName = { + nodeProperty match { + case NodeProperty.ORDER => NodePropertyName.ORDER + case NodeProperty.ARGUMENT_INDEX => NodePropertyName.ARGUMENT_INDEX + case NodeProperty.NAME => NodePropertyName.NAME + case NodeProperty.FULL_NAME => NodePropertyName.FULL_NAME + case NodeProperty.CODE => NodePropertyName.CODE + case NodeProperty.EVALUATION_STRATEGY => + NodePropertyName.EVALUATION_STRATEGY + case NodeProperty.TYPE_FULL_NAME => NodePropertyName.TYPE_FULL_NAME + case NodeProperty.TYPE_DECL_FULL_NAME => + 
NodePropertyName.TYPE_DECL_FULL_NAME + case NodeProperty.SIGNATURE => NodePropertyName.SIGNATURE + case NodeProperty.DISPATCH_TYPE => NodePropertyName.DISPATCH_TYPE + case NodeProperty.METHOD_FULL_NAME => NodePropertyName.METHOD_FULL_NAME + case NodeProperty.METHOD_INST_FULL_NAME => + NodePropertyName.METHOD_INST_FULL_NAME + case NodeProperty.IS_EXTERNAL => NodePropertyName.IS_EXTERNAL + case NodeProperty.PARSER_TYPE_NAME => NodePropertyName.PARSER_TYPE_NAME + case NodeProperty.AST_PARENT_TYPE => NodePropertyName.AST_PARENT_TYPE + case NodeProperty.AST_PARENT_FULL_NAME => + NodePropertyName.AST_PARENT_FULL_NAME + case NodeProperty.LINE_NUMBER => NodePropertyName.LINE_NUMBER + case NodeProperty.COLUMN_NUMBER => NodePropertyName.COLUMN_NUMBER + case NodeProperty.LINE_NUMBER_END => NodePropertyName.LINE_NUMBER_END + case NodeProperty.COLUMN_NUMBER_END => NodePropertyName.COLUMN_NUMBER_END + case NodeProperty.ALIAS_TYPE_FULL_NAME => NodePropertyName.ALIAS_TYPE_FULL_NAME + case NodeProperty.INHERITS_FROM_TYPE_FULL_NAME => NodePropertyName.INHERITS_FROM_TYPE_FULL_NAME + case NodeProperty.CANONICAL_NAME => NodePropertyName.CANONICAL_NAME + } + } + + private def translateNodeKind(nodeKind: NodeKind): Node.NodeType = { + nodeKind match { + case NodeKind.METHOD => Node.NodeType.METHOD + case NodeKind.METHOD_RETURN => Node.NodeType.METHOD_RETURN + case NodeKind.METHOD_PARAMETER_IN => Node.NodeType.METHOD_PARAMETER_IN + case NodeKind.METHOD_INST => Node.NodeType.METHOD_INST + case NodeKind.CALL => Node.NodeType.CALL + case NodeKind.LITERAL => Node.NodeType.LITERAL + case NodeKind.IDENTIFIER => Node.NodeType.IDENTIFIER + case NodeKind.BLOCK => Node.NodeType.BLOCK + case NodeKind.RETURN => Node.NodeType.RETURN + case NodeKind.LOCAL => Node.NodeType.LOCAL + case NodeKind.TYPE => Node.NodeType.TYPE + case NodeKind.TYPE_DECL => Node.NodeType.TYPE_DECL + case NodeKind.MEMBER => Node.NodeType.MEMBER + case NodeKind.NAMESPACE_BLOCK => Node.NodeType.NAMESPACE_BLOCK + case 
NodeKind.CONTROL_STRUCTURE => Node.NodeType.CONTROL_STRUCTURE + case NodeKind.UNKNOWN => Node.NodeType.UNKNOWN + case NodeKind.FIELD_IDENTIFIER => Node.NodeType.FIELD_IDENTIFIER + case NodeKind.JUMP_TARGET => Node.NodeType.JUMP_TARGET + } + } + + private def translateEdgeKind(edgeKind: EdgeKind): Edge.EdgeType = { + edgeKind match { + case EdgeKind.AST => Edge.EdgeType.AST + case EdgeKind.CFG => Edge.EdgeType.CFG + case EdgeKind.REF => Edge.EdgeType.REF + case EdgeKind.CONDITION => Edge.EdgeType.CONDITION + case EdgeKind.ARGUMENT => Edge.EdgeType.ARGUMENT + } + } + +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/astnew/AstToCpgConverter.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/astnew/AstToCpgConverter.scala new file mode 100644 index 0000000..2af78ad --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/astnew/AstToCpgConverter.scala @@ -0,0 +1,912 @@ +package io.shiftleft.fuzzyc2cpg.astnew + +import scala.jdk.CollectionConverters._ +import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, Operators} +import io.shiftleft.fuzzyc2cpg.{Defines, Global} +import io.shiftleft.fuzzyc2cpg.adapter.{CpgAdapter, EdgeKind, NodeKind, NodeProperty} +import io.shiftleft.fuzzyc2cpg.adapter.NodeProperty.NodeProperty +import io.shiftleft.fuzzyc2cpg.ast.AstNode +import io.shiftleft.fuzzyc2cpg.ast.declarations.{ClassDefStatement, IdentifierDecl} +import io.shiftleft.fuzzyc2cpg.ast.expressions._ +import io.shiftleft.fuzzyc2cpg.ast.functionDef.{FunctionDefBase, Template} +import io.shiftleft.fuzzyc2cpg.ast.langc.expressions.{CallExpression, SizeofExpression} +import io.shiftleft.fuzzyc2cpg.ast.langc.functiondef.Parameter +import io.shiftleft.fuzzyc2cpg.ast.langc.statements.blockstarters.IfStatement +import io.shiftleft.fuzzyc2cpg.ast.logical.statements.{BlockStarter, CompoundStatement, Label, Statement} +import io.shiftleft.fuzzyc2cpg.ast.statements.blockstarters.CatchList +import io.shiftleft.fuzzyc2cpg.ast.statements.jump._ +import 
io.shiftleft.fuzzyc2cpg.ast.statements.{ExpressionStatement, IdentifierDeclStatement} +import io.shiftleft.fuzzyc2cpg.ast.walking.ASTNodeVisitor +import io.shiftleft.fuzzyc2cpg.scope.Scope +import io.shiftleft.proto.cpg.Cpg.DispatchTypes +import org.slf4j.LoggerFactory + +object AstToCpgConverter { + private val logger = LoggerFactory.getLogger(getClass) +} + +class AstToCpgConverter[NodeBuilderType, NodeType, EdgeBuilderType, EdgeType]( + cpgParent: NodeType, + adapter: CpgAdapter[NodeBuilderType, NodeType, EdgeBuilderType, EdgeType], + global: Global) + extends ASTNodeVisitor { + import AstToCpgConverter._ + + private var contextStack = List[Context]() + private val scope = new Scope[String, (NodeType, String), NodeType]() + private var methodNode = Option.empty[NodeType] + private var methodReturnNode = Option.empty[NodeType] + + pushContext(cpgParent, 1) + + private class Context(val cpgParent: NodeType, + var childNum: Int, + val parentIsClassDef: Boolean, + val parentIsMemberAccess: Boolean = false, + var addConditionEdgeOnNextAstEdge: Boolean = false, + var addArgumentEdgeOnNextAstEdge: Boolean = false) {} + + private def pushContext(cpgParent: NodeType, + startChildNum: Int, + parentIsClassDef: Boolean = false, + parentIsMemberAccess: Boolean = false): Unit = { + contextStack = new Context(cpgParent, startChildNum, parentIsClassDef, parentIsMemberAccess) :: contextStack + } + + private def popContext(): Unit = { + contextStack = contextStack.tail + } + + private def context: Context = { + contextStack.head + } + + private implicit class NodeBuilderWrapper(nodeBuilder: NodeBuilderType) { + def addProperty(property: NodeProperty, value: String): NodeBuilderType = { + adapter.addNodeProperty(nodeBuilder, property, value) + nodeBuilder + } + + def addProperty(property: NodeProperty, value: Option[Int]): NodeBuilderType = { + value.foreach(adapter.addNodeProperty(nodeBuilder, property, _)) + nodeBuilder + } + + def addProperty(property: NodeProperty, value: 
Int): NodeBuilderType = { + adapter.addNodeProperty(nodeBuilder, property, value) + nodeBuilder + } + def addProperty(property: NodeProperty, value: Boolean): NodeBuilderType = { + adapter.addNodeProperty(nodeBuilder, property, value) + nodeBuilder + } + def addProperty(property: NodeProperty, value: List[String]): NodeBuilderType = { + adapter.addNodeProperty(nodeBuilder, property, value) + nodeBuilder + } + def createNode(astNode: AstNode): NodeType = { + adapter.createNode(nodeBuilder, astNode) + } + def createNode(): NodeType = { + adapter.createNode(nodeBuilder) + } + def addCommons(astNode: AstNode, context: Context): NodeBuilderType = { + nodeBuilder + .addProperty(NodeProperty.CODE, astNode.getEscapedCodeStr) + .addProperty(NodeProperty.ORDER, context.childNum) + .addProperty(NodeProperty.ARGUMENT_INDEX, context.childNum) + .addProperty(NodeProperty.LINE_NUMBER, astNode.getLocation.startLine) + .addProperty(NodeProperty.COLUMN_NUMBER, astNode.getLocation.startPos) + } + } + + private implicit class EdgeBuilderWrapper(edgeBuilder: EdgeBuilderType) { + def createEdge(): EdgeType = { + adapter.createEdge(edgeBuilder) + } + } + + def getMethodNode: Option[NodeType] = { + methodNode + } + + def getMethodReturnNode: Option[NodeType] = { + methodReturnNode + } + + def convert(astNode: AstNode): Unit = { + astNode.accept(this) + } + + override def visit(astFunction: FunctionDefBase): Unit = { + val returnType = if (astFunction.getReturnType != null) { + astFunction.getReturnType.getEscapedCodeStr + } else { + "int" + } + val signature = new StringBuilder() + .append(returnType) + .append("(") + .append(astFunction.getParameterList.getEscapedCodeStr(false)) + .append(")") + .toString() + + val cpgMethod = adapter + .createNodeBuilder(NodeKind.METHOD) + .addProperty(NodeProperty.NAME, astFunction.getName) + .addProperty(NodeProperty.CODE, astFunction.getEscapedCodeStr) + .addProperty(NodeProperty.IS_EXTERNAL, value = false) + .addProperty(NodeProperty.FULL_NAME, 
value = s"${astFunction.getName}") + .addProperty(NodeProperty.LINE_NUMBER, astFunction.getLocation.startLine) + .addProperty(NodeProperty.COLUMN_NUMBER, astFunction.getLocation.startPos) + .addProperty(NodeProperty.LINE_NUMBER_END, astFunction.getLocation.endLine) + .addProperty(NodeProperty.COLUMN_NUMBER_END, astFunction.getLocation.endPos) + .addProperty(NodeProperty.SIGNATURE, signature) + .createNode(astFunction) + + methodNode = Some(cpgMethod) + + addAstChild(cpgMethod) + + pushContext(cpgMethod, 1) + scope.pushNewScope(cpgMethod) + + astFunction.getParameterList.asScala.foreach { parameter => + parameter.accept(this) + } + + val templateParamList = astFunction.getTemplateParameterList + if (templateParamList != null) { + templateParamList.asScala.foreach { template => + template.accept(this) + } + } + + val methodReturnLocation = + if (astFunction.getReturnType != null) { + astFunction.getReturnType.getLocation + } else { + astFunction.getLocation + } + val cpgMethodReturn = adapter + .createNodeBuilder(NodeKind.METHOD_RETURN) + .addProperty(NodeProperty.CODE, "RET") + .addProperty(NodeProperty.EVALUATION_STRATEGY, EvaluationStrategies.BY_VALUE) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(returnType)) + .addProperty(NodeProperty.LINE_NUMBER, methodReturnLocation.startLine) + .addProperty(NodeProperty.COLUMN_NUMBER, methodReturnLocation.startPos) + .addProperty(NodeProperty.ORDER, context.childNum) + .createNode() + + methodReturnNode = Some(cpgMethodReturn) + + addAstChild(cpgMethodReturn) + astFunction.getContent.accept(this) + + scope.popScope() + popContext() + } + + override def visit(astParameter: Parameter): Unit = { + val parameterType = if (astParameter.getType != null) { + astParameter.getType.getEscapedCodeStr + } else { + "int" + } + + val cpgParameter = adapter + .createNodeBuilder(NodeKind.METHOD_PARAMETER_IN) + .addProperty(NodeProperty.CODE, astParameter.getEscapedCodeStr) + .addProperty(NodeProperty.NAME, astParameter.getName) + 
.addProperty(NodeProperty.ORDER, astParameter.getChildNumber + 1) + .addProperty(NodeProperty.EVALUATION_STRATEGY, EvaluationStrategies.BY_VALUE) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(parameterType)) + .addProperty(NodeProperty.LINE_NUMBER, astParameter.getLocation.startLine) + .addProperty(NodeProperty.COLUMN_NUMBER, astParameter.getLocation.startPos) + .createNode(astParameter) + + scope.addToScope(astParameter.getName, (cpgParameter, parameterType)) + addAstChild(cpgParameter) + } + + override def visit(template: Template): Unit = { + // TODO (#60): Populate templated types in CPG. + logger.debug("NYI: Template parsing.") + } + + override def visit(argument: Argument): Unit = { + argument.getExpression.accept(this) + } + + override def visit(argumentList: ArgumentList): Unit = { + acceptChildren(argumentList, withArgEdges = true) + } + + override def visit(astAssignment: AssignmentExpression): Unit = { // Maps the assignment operator text to its Operators name, then delegates to visitBinaryExpr. + val operatorMethod = astAssignment.getOperator match { + case "=" => Operators.assignment + case "*=" => Operators.assignmentMultiplication + case "/=" => Operators.assignmentDivision + case "%=" => Operators.assignmentModulo // compound modulo; must not share the name used for "/=" + case "+=" => Operators.assignmentPlus + case "-=" => Operators.assignmentMinus + case "<<=" => Operators.assignmentShiftLeft + case ">>=" => Operators.assignmentArithmeticShiftRight + case "&=" => Operators.assignmentAnd + case "^=" => Operators.assignmentXor + case "|=" => Operators.assignmentOr + } + visitBinaryExpr(astAssignment, operatorMethod) + } + + override def visit(astAdd: AdditiveExpression): Unit = { + val operatorMethod = astAdd.getOperator match { + case "+" => Operators.addition + case "-" => Operators.subtraction + } + + visitBinaryExpr(astAdd, operatorMethod) + } + + override def visit(astMult: MultiplicativeExpression): Unit = { + val operatorMethod = astMult.getOperator match { + case "*" => Operators.multiplication + case "/" => Operators.division + case "%" => Operators.modulo + } +
visitBinaryExpr(astMult, operatorMethod) + } + + override def visit(astRelation: RelationalExpression): Unit = { + val operatorMethod = astRelation.getOperator match { + case "<" => Operators.lessThan + case ">" => Operators.greaterThan + case "<=" => Operators.lessEqualsThan + case ">=" => Operators.greaterEqualsThan + } + + visitBinaryExpr(astRelation, operatorMethod) + } + + override def visit(astShift: ShiftExpression): Unit = { + val operatorMethod = astShift.getOperator match { + case "<<" => Operators.shiftLeft + case ">>" => Operators.arithmeticShiftRight + } + + visitBinaryExpr(astShift, operatorMethod) + } + + override def visit(astEquality: EqualityExpression): Unit = { + val operatorMethod = astEquality.getOperator match { + case "==" => Operators.equals + case "!=" => Operators.notEquals + } + + visitBinaryExpr(astEquality, operatorMethod) + } + + override def visit(astBitAnd: BitAndExpression): Unit = { + visitBinaryExpr(astBitAnd, Operators.and) + } + + override def visit(astInclOr: InclusiveOrExpression): Unit = { + visitBinaryExpr(astInclOr, Operators.or) + } + + override def visit(astExclOr: ExclusiveOrExpression): Unit = { // Hands the exclusive-or expression to visitBinaryExpr under its own operator name. + visitBinaryExpr(astExclOr, Operators.xor) // Operators.or is already used for inclusive-or above; ^ needs a distinct name + } + + override def visit(astOr: OrExpression): Unit = { + visitBinaryExpr(astOr, Operators.logicalOr) + } + + override def visit(astAnd: AndExpression): Unit = { + visitBinaryExpr(astAnd, Operators.logicalAnd) + } + + override def visit(astUnary: UnaryExpression): Unit = { + Option(astUnary.getChild(0)) match { + case Some(_) => + val operatorMethod = astUnary.getChild(0).getEscapedCodeStr match { + case "+" => Operators.plus + case "-" => Operators.minus + case "*" => Operators.indirection + case "&" => Operators.addressOf + case "~" => Operators.not + case "!"
=> Operators.logicalNot + case "++" => Operators.preIncrement + case "--" => Operators.preDecrement + } + + val cpgUnary = createCallNode(astUnary, operatorMethod) + + addAstChild(cpgUnary) + + pushContext(cpgUnary, 1) + context.addArgumentEdgeOnNextAstEdge = true + astUnary.getChild(1).accept(this) + popContext() + case None => + // We get here for `new` expression. + val cpgNew = newUnknownNode(astUnary) + + addAstChild(cpgNew) + } + } + + override def visit(astPostIncDecOp: PostIncDecOperationExpression): Unit = { + val operatorMethod = astPostIncDecOp.getChild(1).getEscapedCodeStr match { + case "++" => Operators.postIncrement + case "--" => Operators.postDecrement + } + + val cpgPostIncDecOp = createCallNode(astPostIncDecOp, operatorMethod) + + addAstChild(cpgPostIncDecOp) + + pushContext(cpgPostIncDecOp, 1) + context.addArgumentEdgeOnNextAstEdge = true + astPostIncDecOp.getChild(0).accept(this) + popContext() + } + + override def visit(astCall: CallExpression): Unit = { + val targetMethodName = astCall.getChild(0).getEscapedCodeStr + // TODO the DISPATCH_TYPE needs to depend on the type of the identifier which is "called". + // At the moment we use STATIC_DISPATCH also for calls of function pointers. + // When this is done we need to draw a RECEIVER edge for DYNAMIC_DISPATCH function pointer + // calls to the pointer expression. + val cpgCall = createCallNode(astCall, targetMethodName) + + addAstChild(cpgCall) + + pushContext(cpgCall, 1) + // Argument edges are added when visiting each individual argument. 
+ astCall.getArgumentList.accept(this) + popContext() + } + + override def visit(astNew: NewExpression): Unit = { + val call = createCallNode(astNew, ".new") + + addAstChild(call) + pushContext(call, 1) + context.addArgumentEdgeOnNextAstEdge = true + astNew.getTargetClass.accept(this) + astNew.getArgumentList.accept(this) + popContext() + } + + override def visit(astDelete: DeleteExpression): Unit = { + val call = createCallNode(astDelete, Operators.delete); + + addAstChild(call) + pushContext(call, 1) + context.addArgumentEdgeOnNextAstEdge = true; + astDelete.getTarget.accept(this) + popContext() + } + + override def visit(astConstant: Constant): Unit = { + val constantType = deriveConstantTypeFromCode(astConstant.getEscapedCodeStr) + val cpgConstant = adapter + .createNodeBuilder(NodeKind.LITERAL) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(constantType)) + .addCommons(astConstant, context) + .createNode(astConstant) + + addAstChild(cpgConstant) + } + + override def visit(astBreak: BreakStatement): Unit = { + addAstChild(newControlStructureNode(astBreak)) + } + + override def visit(astContinue: ContinueStatement): Unit = { + addAstChild(newControlStructureNode(astContinue)) + } + + override def visit(astGoto: GotoStatement): Unit = { + addAstChild(newControlStructureNode(astGoto)) + } + + override def visit(astIdentifier: Identifier): Unit = { + val identifierName = astIdentifier.getEscapedCodeStr + + if (!contextStack.isEmpty && contextStack.head.parentIsMemberAccess && contextStack.head.childNum == 2) { + val cpgFieldIdentifier = adapter + .createNodeBuilder(NodeKind.FIELD_IDENTIFIER) + .addProperty(NodeProperty.CANONICAL_NAME, identifierName) + .addCommons(astIdentifier, context) + .createNode(astIdentifier) + addAstChild(cpgFieldIdentifier) + return + } + + val variableOption = scope.lookupVariable(identifierName) + val identifierTypeName = variableOption match { + case Some((_, variableTypeName)) => + variableTypeName + case None => + 
Defines.anyTypeName + } + + val cpgIdentifier = adapter + .createNodeBuilder(NodeKind.IDENTIFIER) + .addProperty(NodeProperty.NAME, identifierName) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(identifierTypeName)) + .addCommons(astIdentifier, context) + .createNode(astIdentifier) + + addAstChild(cpgIdentifier) + + variableOption match { + case Some((variable, _)) => + adapter + .createEdgeBuilder(variable, cpgIdentifier, EdgeKind.REF) + .createEdge() + case None => + } + + } + + override def visit(condition: Condition): Unit = { + //not called for ConditionalExpression, cf joern#91 + context.addConditionEdgeOnNextAstEdge = true + condition.getExpression.accept(this) + } + + override def visit(astConditionalExpr: ConditionalExpression): Unit = { + //this ought to be a ControlStructureNode, but we currently cannot handle that in the dataflow tracker + val cpgConditionalExpr = createCallNode(astConditionalExpr, ".conditionalExpression") + addAstChild(cpgConditionalExpr) + val condition = astConditionalExpr.getChild(0).asInstanceOf[Condition] + val trueExpression = astConditionalExpr.getChild(1) + val falseExpression = astConditionalExpr.getChild(2) + // avoid setting context.addConditionEdgeOnNextAstEdge in this.visit(condition), cf joern#91 + pushContext(cpgConditionalExpr, 1) + context.addArgumentEdgeOnNextAstEdge = true + condition.getExpression.accept(this) + context.addArgumentEdgeOnNextAstEdge = true + trueExpression.accept(this) + context.addArgumentEdgeOnNextAstEdge = true + falseExpression.accept(this) + popContext() + } + + override def visit(expression: Expression): Unit = { + // We only end up here for expressions chained by ','. + // Those expressions are than the children of the expression + // given as parameter. 
+ val classOfExpression = expression.getClass + if (classOfExpression != classOf[Expression]) { + throw new RuntimeException( + s"Only direct instances of Expressions expected " + + s"but ${classOfExpression.getSimpleName} found") + } + + val cpgBlock = adapter + .createNodeBuilder(NodeKind.BLOCK) + .addProperty(NodeProperty.CODE, "") + .addProperty(NodeProperty.ORDER, context.childNum) + .addProperty(NodeProperty.ARGUMENT_INDEX, context.childNum) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(Defines.anyTypeName)) + .addProperty(NodeProperty.LINE_NUMBER, expression.getLocation.startLine) + .addProperty(NodeProperty.COLUMN_NUMBER, expression.getLocation.startPos) + .createNode(expression) + + addAstChild(cpgBlock) + + pushContext(cpgBlock, 1) + acceptChildren(expression) + popContext() + } + + override def visit(forInit: ForInit): Unit = { + acceptChildren(forInit) + } + + override def visit(astBlockStarter: BlockStarter): Unit = { + val cpgBlockStarter = newControlStructureNode(astBlockStarter) + addAstChild(cpgBlockStarter) + pushContext(cpgBlockStarter, 1) + + acceptChildren(astBlockStarter) + + popContext() + } + + override def visit(astCatchList: CatchList): Unit = { + val cpgCatchList = newUnknownNode(astCatchList) + addAstChild(cpgCatchList) + + pushContext(cpgCatchList, 1) + astCatchList.asScala.foreach { catchElement => + catchElement.accept(this) + } + popContext() + } + + override def visit(astThrow: ThrowStatement): Unit = { + val cpgThrow = newControlStructureNode(astThrow) + + addAstChild(cpgThrow) + + pushContext(cpgThrow, 1) + val throwExpression = astThrow.getThrowExpression + if (throwExpression != null) { + throwExpression.accept(this) + } + popContext() + } + + override def visit(astIfStmt: IfStatement): Unit = { + val cpgIfStmt = newControlStructureNode(astIfStmt) + addAstChild(cpgIfStmt) + pushContext(cpgIfStmt, 1) + + astIfStmt.getCondition.accept(this) + astIfStmt.getStatement.accept(this) + val astElseStmt = astIfStmt.getElseNode 
+ if (astElseStmt != null) { + astElseStmt.accept(this) + } + popContext() + } + + override def visit(statement: ExpressionStatement): Unit = { + Option(statement.getExpression).foreach(_.accept(this)) + } + + override def visit(astBlock: CompoundStatement): Unit = { + if (context.parentIsClassDef) { + astBlock.getStatements.asScala.foreach { statement => + statement.accept(this) + } + } else { + val cpgBlock = adapter + .createNodeBuilder(NodeKind.BLOCK) + .addProperty(NodeProperty.CODE, "") + .addProperty(NodeProperty.ORDER, context.childNum) + .addProperty(NodeProperty.ARGUMENT_INDEX, context.childNum) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(Defines.voidTypeName)) + .addProperty(NodeProperty.LINE_NUMBER, astBlock.getLocation.startLine) + .addProperty(NodeProperty.COLUMN_NUMBER, astBlock.getLocation.startPos) + .createNode(astBlock) + + addAstChild(cpgBlock) + + pushContext(cpgBlock, 1) + scope.pushNewScope(cpgBlock) + astBlock.getStatements.asScala.foreach { statement => + statement.accept(this) + } + popContext() + scope.popScope() + } + } + + override def visit(astReturn: ReturnStatement): Unit = { + val cpgReturn = adapter + .createNodeBuilder(NodeKind.RETURN) + .addCommons(astReturn, context) + .createNode(astReturn) + + addAstChild(cpgReturn) + + pushContext(cpgReturn, 1) + Option(astReturn.getReturnExpression).foreach { returnExpr => + context.addArgumentEdgeOnNextAstEdge = true + returnExpr.accept(this) + } + popContext() + } + + override def visit(astIdentifierDeclStmt: IdentifierDeclStatement): Unit = { + astIdentifierDeclStmt.getIdentifierDeclList.asScala.foreach { identifierDecl => + identifierDecl.accept(this) + } + } + + override def visit(identifierDecl: IdentifierDecl): Unit = { + val declTypeName = identifierDecl.getType.getEscapedCodeStr + + if (identifierDecl.isTypedef) { + val aliasTypeDecl = adapter + .createNodeBuilder(NodeKind.TYPE_DECL) + .addProperty(NodeProperty.NAME, identifierDecl.getName.getEscapedCodeStr) + 
.addProperty(NodeProperty.FULL_NAME, identifierDecl.getName.getEscapedCodeStr) + .addProperty(NodeProperty.IS_EXTERNAL, value = false) + .addProperty(NodeProperty.ALIAS_TYPE_FULL_NAME, registerType(declTypeName)) + .createNode(identifierDecl) + + addAstChild(aliasTypeDecl) + } else if (context.parentIsClassDef) { + val cpgMember = adapter + .createNodeBuilder(NodeKind.MEMBER) + .addProperty(NodeProperty.CODE, identifierDecl.getEscapedCodeStr) + .addProperty(NodeProperty.NAME, identifierDecl.getName.getEscapedCodeStr) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(declTypeName)) + .createNode(identifierDecl) + addAstChild(cpgMember) + } else { + // We only process file level identifier declarations if they are typedefs. + // Everything else is ignored. + if (!scope.isEmpty) { + val localName = identifierDecl.getName.getEscapedCodeStr + val cpgLocal = adapter + .createNodeBuilder(NodeKind.LOCAL) + .addProperty(NodeProperty.CODE, localName) + .addProperty(NodeProperty.NAME, localName) + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(declTypeName)) + .addProperty(NodeProperty.ORDER, context.childNum) + .createNode(identifierDecl) + + val scopeParentNode = + scope.addToScope(localName, (cpgLocal, declTypeName)) + // Here we on purpose do not use addAstChild because the LOCAL nodes + // are not really in the AST (they also have no ORDER property). + // So do not be confused that the format still demands an AST edge. + adapter + .createEdgeBuilder(cpgLocal, scopeParentNode, EdgeKind.AST) + .createEdge() + + val assignmentExpression = identifierDecl.getAssignment + if (assignmentExpression != null) { + assignmentExpression.accept(this) + } + } + } + } + + override def visit(astSizeof: SizeofExpression): Unit = { + val cpgSizeof = createCallNode(astSizeof, Operators.sizeOf) + + addAstChild(cpgSizeof) + + pushContext(cpgSizeof, 1) + // Child 0 is just the keyword 'sizeof' which at this point is duplicate + // information for us. 
+ context.addArgumentEdgeOnNextAstEdge = true + astSizeof.getChild(1).accept(this) + popContext() + } + + override def visit(astSizeofOperand: SizeofOperand): Unit = { // Visits the sizeof operand: a bare type (no children) or a single expression child. + astSizeofOperand.getChildCount match { + case 0 => + // Operand is a type. + val cpgTypeRef = newUnknownNode(astSizeofOperand) + addAstChild(cpgTypeRef) + case 1 => + // Operand is an expression. + astSizeofOperand.getChild(0).accept(this) // with exactly one child, it sits at index 0 + } + } + + override def visit(astLabel: Label): Unit = { + val cpgLabel = adapter + .createNodeBuilder(NodeKind.JUMP_TARGET) + .addProperty(NodeProperty.PARSER_TYPE_NAME, astLabel.getClass.getSimpleName) + .addProperty(NodeProperty.NAME, astLabel.getLabelName) + .addProperty(NodeProperty.CODE, astLabel.getEscapedCodeStr) + .addCommons(astLabel, context) + .createNode(astLabel) + addAstChild(cpgLabel) + } + + override def visit(astArrayIndexing: ArrayIndexing): Unit = { + val cpgArrayIndexing = + createCallNode(astArrayIndexing, Operators.indirectIndexAccess) + + addAstChild(cpgArrayIndexing) + + pushContext(cpgArrayIndexing, 1) + context.addArgumentEdgeOnNextAstEdge = true + astArrayIndexing.getArrayExpression.accept(this) + context.addArgumentEdgeOnNextAstEdge = true + astArrayIndexing.getIndexExpression.accept(this) + popContext() + } + + override def visit(astCast: CastExpression): Unit = { + val cpgCast = createCallNode(astCast, Operators.cast) + + addAstChild(cpgCast) + + pushContext(cpgCast, 1) + context.addArgumentEdgeOnNextAstEdge = true + astCast.getCastTarget.accept(this) + context.addArgumentEdgeOnNextAstEdge = true + astCast.getCastExpression.accept(this) + popContext() + } + + override def visit(astMemberAccess: MemberAccess): Unit = { + val cpgMemberAccess = + createCallNode(astMemberAccess, Operators.fieldAccess) + + addAstChild(cpgMemberAccess) + + pushContext(cpgMemberAccess, 1, parentIsMemberAccess = true) + acceptChildren(astMemberAccess, withArgEdges = true) + popContext() + } + + override def visit(astPtrMemberAccess: PtrMemberAccess): Unit
= { + val cpgPtrMemberAccess = + createCallNode(astPtrMemberAccess, Operators.indirectFieldAccess) + + addAstChild(cpgPtrMemberAccess) + + pushContext(cpgPtrMemberAccess, 1, parentIsMemberAccess = true) + acceptChildren(astPtrMemberAccess, withArgEdges = true) + popContext() + } + + override def visit(astCastTarget: CastTarget): Unit = { + val cpgCastTarget = newUnknownNode(astCastTarget) + addAstChild(cpgCastTarget) + } + + override def visit(astInitializerList: InitializerList): Unit = { + // TODO figure out how to represent. + } + + override def visit(statement: Statement): Unit = { + if (statement.getChildCount != 0) { + throw new RuntimeException("Unhandled statement type: " + statement.getClass) + } else { + logger.debug("Parse error. Code: {}", statement.getEscapedCodeStr) + } + } + + override def visit(astClassDef: ClassDefStatement): Unit = { + // TODO: currently NAME and FULL_NAME are the same, since + // the parser does not detect C++ namespaces. Change that, + // once the parser handles namespaces. 
+ var name = astClassDef.identifier.toString + name = name.substring(1, name.length - 1) + val baseClassList = astClassDef.baseClasses.asScala.map { identifier => + val baseClassName = identifier.toString + baseClassName.substring(1, baseClassName.length - 1) + }.toList + + val cpgTypeDeclBuilder = adapter + .createNodeBuilder(NodeKind.TYPE_DECL) + .addProperty(NodeProperty.NAME, name) + .addProperty(NodeProperty.FULL_NAME, name) + .addProperty(NodeProperty.IS_EXTERNAL, value = false) + if (!baseClassList.isEmpty) { + cpgTypeDeclBuilder.addProperty(NodeProperty.INHERITS_FROM_TYPE_FULL_NAME, baseClassList) + baseClassList.map { registerType(_) } + } + val cpgTypeDecl = cpgTypeDeclBuilder.createNode(astClassDef) + + addAstChild(cpgTypeDecl) + + val templateParamList = astClassDef.getTemplateParameterList + if (templateParamList != null) { + templateParamList.asScala.foreach { template => + template.accept(this) + } + } + + pushContext(cpgTypeDecl, 1, parentIsClassDef = true) + astClassDef.content.accept(this) + popContext() + } + + private def visitBinaryExpr(astBinaryExpr: BinaryExpression, operatorMethod: String): Unit = { + val cpgBinaryExpr = createCallNode(astBinaryExpr, operatorMethod) + + addAstChild(cpgBinaryExpr) + + pushContext(cpgBinaryExpr, 1) + + context.addArgumentEdgeOnNextAstEdge = true + astBinaryExpr.getLeft.accept(this) + + context.addArgumentEdgeOnNextAstEdge = true + astBinaryExpr.getRight.accept(this) + + popContext() + } + + private def addAstChild(child: NodeType): Unit = { + adapter + .createEdgeBuilder(child, context.cpgParent, EdgeKind.AST) + .createEdge() + + context.childNum += 1 + + if (context.addConditionEdgeOnNextAstEdge) { + addConditionChild(child) + context.addConditionEdgeOnNextAstEdge = false + } + + if (context.addArgumentEdgeOnNextAstEdge) { + addArgumentChild(child) + context.addArgumentEdgeOnNextAstEdge = false + } + } + + private def addConditionChild(child: NodeType): Unit = { + adapter + .createEdgeBuilder(child, 
context.cpgParent, EdgeKind.CONDITION) + .createEdge() + } + + private def addArgumentChild(child: NodeType): Unit = { + adapter + .createEdgeBuilder(child, context.cpgParent, EdgeKind.ARGUMENT) + .createEdge() + } + + private def newUnknownNode(astNode: AstNode): NodeType = { + adapter + .createNodeBuilder(NodeKind.UNKNOWN) + .addProperty(NodeProperty.PARSER_TYPE_NAME, astNode.getClass.getSimpleName) + .addCommons(astNode, context) + .createNode(astNode) + } + + private def newControlStructureNode(astNode: AstNode): NodeType = { + adapter + .createNodeBuilder(NodeKind.CONTROL_STRUCTURE) + .addProperty(NodeProperty.PARSER_TYPE_NAME, astNode.getClass.getSimpleName) + .addCommons(astNode, context) + .createNode(astNode) + } + + private def createCallNode(astNode: AstNode, methodName: String): NodeType = { + val cpgNode = adapter + .createNodeBuilder(NodeKind.CALL) + .addProperty(NodeProperty.NAME, methodName) + .addProperty(NodeProperty.DISPATCH_TYPE, DispatchTypes.STATIC_DISPATCH.name()) + .addProperty(NodeProperty.SIGNATURE, "TODO assignment signature") + .addProperty(NodeProperty.TYPE_FULL_NAME, registerType(Defines.anyTypeName)) + .addProperty(NodeProperty.METHOD_FULL_NAME, methodName) + .addCommons(astNode, context) + .createNode(astNode) + + cpgNode + } + + private def acceptChildren(node: AstNode, withArgEdges: Boolean = false): Unit = { + node.getChildIterator.forEachRemaining { child => + context.addArgumentEdgeOnNextAstEdge = withArgEdges + child.accept(this) + } + } + + private def registerType(typeName: String): String = { + global.usedTypes.put(typeName, true) + typeName + } + + // TODO Implement this method properly, the current implementation is just a + // quick hack to have some implementation at all. 
+ private def deriveConstantTypeFromCode(code: String): String = { // Guesses a literal's type name from the first and last characters of its source text. + val firstChar = code.headOption.getOrElse(' ') // blank stand-in: an empty literal falls through to the int default instead of throwing + val lastChar = code.lastOption.getOrElse(' ') + if (firstChar == '"') { + Defines.charPointerTypeName + } else if (firstChar == '\'') { + Defines.charTypeName + } else if (lastChar == 'f' || lastChar == 'F') { + Defines.floatTypeName + } else if (lastChar == 'd' || lastChar == 'D') { + Defines.doubleTypeName + } else if (code.endsWith("ll") || code.endsWith("LL")) { // tested before the single-'l' suffix, which would otherwise always win + Defines.longlongTypeName + } else if (lastChar == 'l' || lastChar == 'L') { + Defines.longTypeName + } else { + Defines.intTypeName + } + } +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/cfg/AstToCfgConverter.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/cfg/AstToCfgConverter.scala new file mode 100644 index 0000000..3691c87 --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/cfg/AstToCfgConverter.scala @@ -0,0 +1,561 @@ +package io.shiftleft.fuzzyc2cpg.cfg + +import io.shiftleft.fuzzyc2cpg.adapter.EdgeProperty.EdgeProperty +import io.shiftleft.fuzzyc2cpg.adapter.{ + AlwaysEdge, + CaseEdge, + CfgEdgeType, + CpgAdapter, + EdgeKind, + EdgeProperty, + FalseEdge, + TrueEdge +} +import io.shiftleft.fuzzyc2cpg.ast.AstNode +import io.shiftleft.fuzzyc2cpg.ast.declarations.{ClassDefStatement, IdentifierDecl} +import io.shiftleft.fuzzyc2cpg.ast.expressions._ +import io.shiftleft.fuzzyc2cpg.ast.langc.expressions.SizeofExpression +import io.shiftleft.fuzzyc2cpg.ast.langc.functiondef.FunctionDef +import io.shiftleft.fuzzyc2cpg.ast.langc.statements.blockstarters.{ElseStatement, IfStatement} +import io.shiftleft.fuzzyc2cpg.ast.logical.statements.{CompoundStatement, Label, Statement} +import io.shiftleft.fuzzyc2cpg.ast.statements.{ExpressionStatement, IdentifierDeclStatement} +import io.shiftleft.fuzzyc2cpg.ast.statements.blockstarters._ +import io.shiftleft.fuzzyc2cpg.ast.statements.jump._ +import io.shiftleft.fuzzyc2cpg.ast.walking.ASTNodeVisitor +import org.slf4j.LoggerFactory + +import
scala.jdk.CollectionConverters._ + +object AstToCfgConverter { + private val logger = LoggerFactory.getLogger(getClass) +} + +class AstToCfgConverter[NodeType, EdgeBuilderType, EdgeType]( + entryNode: NodeType, + exitNode: NodeType, + adapter: CpgAdapter[_, NodeType, EdgeBuilderType, EdgeType] = null) + extends ASTNodeVisitor { + import AstToCfgConverter._ + + private case class FringeElement(node: NodeType, cfgEdgeType: CfgEdgeType) + + private implicit class FringeWrapper(fringe: List[FringeElement]) { + def setCfgEdgeType(cfgEdgeType: CfgEdgeType): List[FringeElement] = { + fringe.map { + case FringeElement(node, _) => + FringeElement(node, cfgEdgeType) + } + } + + def add(node: NodeType, cfgEdgeType: CfgEdgeType): List[FringeElement] = { + FringeElement(node, cfgEdgeType) :: fringe + } + + def add(nodes: List[NodeType], cfgEdgeType: CfgEdgeType): List[FringeElement] = { + nodes.map(node => FringeElement(node, cfgEdgeType)) ++ fringe + } + + def add(otherFringe: List[FringeElement]): List[FringeElement] = { + otherFringe ++ fringe + } + } + + private implicit class EdgeBuilderWrapper2(edgeBuilder: EdgeBuilderType) { + def addProperty(property: EdgeProperty, value: String): EdgeBuilderType = { + adapter.addEdgeProperty(edgeBuilder, property, value) + edgeBuilder + } + def createEdge(): EdgeType = { + adapter.createEdge(edgeBuilder) + } + } + + private def extendCfg(astDstNode: AstNode): Unit = { + val dstNode = adapter.mapNode(astDstNode) + extendCfg(dstNode) + } + + private def extendCfg(dstNode: NodeType): Unit = { + fringe.foreach { + case FringeElement(srcNode, cfgEdgeType) => + adapter + .createEdgeBuilder(dstNode, srcNode, EdgeKind.CFG) + .addProperty(EdgeProperty.CFG_EDGE_TYPE, cfgEdgeType.toString) + .createEdge() + } + fringe = Nil.add(dstNode, AlwaysEdge) + + if (markerStack.nonEmpty) { + // Up until the first none None stack element we replace the Nones with Some(dstNode) + val leadingNoneLength = markerStack.segmentLength(_.isEmpty, 0) + markerStack = 
List.fill(leadingNoneLength)(Some(dstNode)) ++ markerStack + .drop(leadingNoneLength) + } + + if (pendingGotoLabels.nonEmpty) { + pendingGotoLabels.foreach { label => + labeledNodes = labeledNodes + (label -> dstNode) + } + pendingGotoLabels = List() + } + + // TODO at the moment we discard the case labels + if (pendingCaseLabels.nonEmpty) { + // Under normal conditions this is always true. + // But if the parser missed a switch statement, caseStack + // might by empty. + if (caseStack.numberOfLayers > 0) { + val containsDefaultLabel = pendingCaseLabels.contains("default") + caseStack.store((dstNode, containsDefaultLabel)) + } + pendingCaseLabels = List() + } + + } + + private var fringe = List[FringeElement]().add(entryNode, AlwaysEdge) + private var markerStack = List[Option[NodeType]]() // Used to track the start of yet to be processed + // cfg parts. + private val breakStack = new LayeredStack[NodeType]() + private val continueStack = new LayeredStack[NodeType]() + private val caseStack = new LayeredStack[(NodeType, Boolean)]() + private var gotos = List[(NodeType, String)]() + private var returns = List[NodeType]() + private var labeledNodes = Map[String, NodeType]() + private var pendingGotoLabels = List[String]() + private var pendingCaseLabels = List[String]() + + private def connectGotosAndLabels(): Unit = { + gotos.foreach { + case (goto, label) => + labeledNodes.get(label) match { + case Some(labeledNode) => + adapter + .createEdgeBuilder(labeledNode, goto, EdgeKind.CFG) + .addProperty(EdgeProperty.CFG_EDGE_TYPE, AlwaysEdge.toString) + .createEdge() + case None => + logger.info("Unable to wire goto statement. 
Missing label {}.", label) + } + } + } + + private def connectReturnsToExit(): Unit = { + returns.foreach { ret => + adapter + .createEdgeBuilder(exitNode, ret, EdgeKind.CFG) + .addProperty(EdgeProperty.CFG_EDGE_TYPE, AlwaysEdge.toString) + .createEdge() + } + } + + def convert(astNode: AstNode): Unit = { + astNode.accept(this) + extendCfg(exitNode) + connectGotosAndLabels() + connectReturnsToExit() + } + + override def visit(argument: Argument): Unit = { + argument.getExpression.accept(this) + } + + override def visit(argumentList: ArgumentList): Unit = { + acceptChildren(argumentList) + } + + override def visit(arrayIndexing: ArrayIndexing): Unit = { + arrayIndexing.getArrayExpression.accept(this) + arrayIndexing.getIndexExpression.accept(this) + extendCfg(arrayIndexing) + } + + override def visit(binaryExpression: BinaryExpression): Unit = { + binaryExpression.getLeft.accept(this) + binaryExpression.getRight.accept(this) + extendCfg(binaryExpression) + } + + override def visit(astAND: AndExpression): Unit = { + astAND.getLeft.accept(this) + val entry = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + astAND.getRight.accept(this) + fringe = fringe.add(entry.setCfgEdgeType(FalseEdge)) + extendCfg(astAND) + } + + override def visit(astOR: OrExpression): Unit = { + astOR.getLeft.accept(this) + val entry = fringe + fringe = fringe.setCfgEdgeType(FalseEdge) + astOR.getRight.accept(this) + fringe = fringe.add(entry.setCfgEdgeType(TrueEdge)) + extendCfg(astOR) + } + + override def visit(breakStatement: BreakStatement): Unit = { + val mappedBreak = adapter.mapNode(breakStatement) + extendCfg(mappedBreak) + // Under normal conditions this is always true. + // But if the parser missed a loop or switch statement, breakStack + // might by empty. 
+ if (breakStack.numberOfLayers > 0) { + fringe = Nil + breakStack.store(mappedBreak) + } + } + + override def visit(castExpression: CastExpression): Unit = { + castExpression.getCastExpression.accept(this) + extendCfg(castExpression) + } + + // TODO we do not handle the 'targetFunc' field of callExpression yet. + // This leads to not correctly handling calls via function pointers. + // Fix this once we change CALL side representation for this. + override def visit(callExpression: CallExpressionBase): Unit = { + callExpression.getArgumentList.accept(this) + extendCfg(callExpression) + } + + override def visit(classDefStatement: ClassDefStatement): Unit = { + // Class defs are not put into the control flow in CPG format. + } + + override def visit(compoundStatement: CompoundStatement): Unit = { + compoundStatement.getStatements.asScala.foreach { statement => + statement.accept(this) + } + } + + override def visit(condition: Condition): Unit = { + condition.getExpression.accept(this) + } + + // TODO we would prefer to unify conditional expressions and control structures. + // The data flow tracker cannot deal with this correctly, so we use a + // CALL with nonstandard control flow (argument evaluation order) instgead. + override def visit(conditionalExpression: ConditionalExpression): Unit = { + val condition = conditionalExpression.getChild(0) + val trueExpression = conditionalExpression.getChild(1) + val falseExpression = conditionalExpression.getChild(2) + + condition.accept(this) + val fromCond = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + trueExpression.accept(this) + val fromTrue = fringe + fringe = fromCond.setCfgEdgeType(FalseEdge) + falseExpression.accept(this) + fringe = fringe.add(fromTrue) + extendCfg(conditionalExpression) + } + + override def visit(continueStatement: ContinueStatement): Unit = { + val mappedContinue = adapter.mapNode(continueStatement) + extendCfg(mappedContinue) + // Under normal conditions this is always true. 
+ // But if the parser missed a loop statement, continueStack + // might by empty. + if (continueStack.numberOfLayers > 0) { + fringe = Nil + continueStack.store(mappedContinue) + } + } + + override def visit(constant: Constant): Unit = { + extendCfg(constant) + } + + override def visit(doStatement: DoStatement): Unit = { + breakStack.pushLayer() + continueStack.pushLayer() + + markerStack = None :: markerStack + doStatement.getStatement.accept(this) + + fringe = fringe.add(continueStack.getTopElements, AlwaysEdge) + + Option(doStatement.getCondition) match { + case Some(condition) => + condition.accept(this) + val conditionFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + + extendCfg(markerStack.head.get) + + fringe = conditionFringe.setCfgEdgeType(FalseEdge) + case None => + // We only get here if the parser missed the condition. + // In this case doing nothing here means that we have + // no CFG edge to the loop start because we default + // to an always false condition. + } + fringe = fringe.add(breakStack.getTopElements, AlwaysEdge) + + markerStack = markerStack.tail + breakStack.popLayer() + continueStack.popLayer() + } + + override def visit(elseStatement: ElseStatement): Unit = { + acceptChildren(elseStatement) + } + + override def visit(expression: Expression): Unit = { + // We only end up here for expressions chained by ','. + // Those expressions are than the children of the expression + // given as parameter. 
+ val classOfExpression = expression.getClass + if (classOfExpression != classOf[Expression]) { + throw new RuntimeException( + s"Only direct instances of Expressions expected " + + s"but ${classOfExpression.getSimpleName} found") + } + + acceptChildren(expression) + } + + override def visit(expressionStatement: ExpressionStatement): Unit = { + Option(expressionStatement.getExpression).foreach(_.accept(this)) + } + + override def visit(forInit: ForInit): Unit = { + acceptChildren(forInit) + } + + override def visit(forStatement: ForStatement): Unit = { + breakStack.pushLayer() + continueStack.pushLayer() + + Option(forStatement.getForInitExpression).foreach(_.accept(this)) + + markerStack = None :: markerStack + val conditionOption = Option(forStatement.getCondition) + val conditionFringe = + conditionOption match { + case Some(condition) => + condition.accept(this) + val storedFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + storedFringe + case None => Nil + } + + forStatement.getStatement.accept(this) + + fringe = fringe.add(continueStack.getTopElements, AlwaysEdge) + + Option(forStatement.getForLoopExpression).foreach(_.accept(this)) + + markerStack.head.foreach(extendCfg) + + fringe = conditionFringe + .setCfgEdgeType(FalseEdge) + .add(breakStack.getTopElements, AlwaysEdge) + + markerStack = markerStack.tail + breakStack.popLayer() + continueStack.popLayer() + } + + override def visit(functionDef: FunctionDef): Unit = { + functionDef.getContent.accept(this) + } + + override def visit(gotoStatement: GotoStatement): Unit = { + val mappedGoto = adapter.mapNode(gotoStatement) + extendCfg(mappedGoto) + fringe = Nil + gotos = (mappedGoto, gotoStatement.getTargetName) :: gotos + } + + override def visit(identifier: Identifier): Unit = { + extendCfg(identifier) + } + + override def visit(identifierDecl: IdentifierDecl): Unit = { + val assignment = identifierDecl.getAssignment + if (assignment != null) { + assignment.accept(this) + } + } + + override def 
visit(identifierDeclStatement: IdentifierDeclStatement): Unit = { + identifierDeclStatement.getIdentifierDeclList.asScala.foreach { identifierDecl => + identifierDecl.accept(this) + } + } + + override def visit(ifStatement: IfStatement): Unit = { + ifStatement.getCondition.accept(this) + val conditionFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + + ifStatement.getStatement.accept(this) + + Option(ifStatement.getElseNode) match { + case Some(elseStatement) => + val ifBlockFringe = fringe + fringe = conditionFringe.setCfgEdgeType(FalseEdge) + elseStatement.accept(this) + fringe = fringe.add(ifBlockFringe) + case None => + fringe = fringe.add(conditionFringe.setCfgEdgeType(FalseEdge)) + } + } + + override def visit(initializerList: InitializerList): Unit = { + // TODO figure out how to represent. + } + + override def visit(label: Label): Unit = { + val labelName = label.getLabelName + if (labelName.startsWith("case") || labelName.startsWith("default")) { + pendingCaseLabels = labelName :: pendingCaseLabels + } else { + pendingGotoLabels = labelName :: pendingGotoLabels + } + } + + override def visit(memberAccess: MemberAccess): Unit = { + acceptChildren(memberAccess) + extendCfg(memberAccess) + } + + override def visit(ptrMemberAccess: PtrMemberAccess): Unit = { + acceptChildren(ptrMemberAccess) + extendCfg(ptrMemberAccess) + } + + // TODO We here assume that the post inc/dec is executed like a normal operation + // and not at the end of the statement. 
+ override def visit(postIncDecOperationExpression: PostIncDecOperationExpression): Unit = { + postIncDecOperationExpression.getChild(0).accept(this) + extendCfg(postIncDecOperationExpression) + } + + override def visit(returnStatement: ReturnStatement): Unit = { + Option(returnStatement.getReturnExpression).foreach(_.accept(this)) + val mappedReturnStatement = adapter.mapNode(returnStatement) + extendCfg(mappedReturnStatement) + fringe = Nil + returns = mappedReturnStatement :: returns + } + + override def visit(sizeofExpression: SizeofExpression): Unit = { + sizeofExpression.getChild(1).accept(this) + extendCfg(sizeofExpression) + } + + override def visit(sizeofOperand: SizeofOperand): Unit = { + sizeofOperand.getChildCount match { + case 0 => + // Operand is a type. We do not add the type to the CFG. + case 1 => + // Operand is an expression. + sizeofOperand.getChild(0).accept(this) + } + } + + override def visit(statement: Statement): Unit = { + if (statement.getChildCount != 0) { + throw new RuntimeException("Unhandled statement type: " + statement.getClass) + } + } + + override def visit(switchStatement: SwitchStatement): Unit = { + switchStatement.getCondition.accept(this) + val conditionFringe = fringe.setCfgEdgeType(CaseEdge) + fringe = Nil + + // We can only push the break and case stacks after we processed the condition + // in order to allow for nested switches with no nodes CFG nodes in between + // an outer switch case label and the inner switch condition. + // This is ok because in C/C++ it is not allowed to have another switch + // statement in the condition of a switch statement. 
+ breakStack.pushLayer() + caseStack.pushLayer() + + switchStatement.getStatement.accept(this) + val switchFringe = fringe + + caseStack.getTopElements.foreach { + case (caseNode, _) => + fringe = conditionFringe + extendCfg(caseNode) + } + + val hasDefaultCase = caseStack.getTopElements.exists { + case (_, isDefault) => + isDefault + } + + fringe = switchFringe.add(breakStack.getTopElements, AlwaysEdge) + + if (!hasDefaultCase) { + fringe = fringe.add(conditionFringe) + } + + breakStack.popLayer() + caseStack.popLayer() + } + + override def visit(throwStatement: ThrowStatement): Unit = { + val throwExpression = throwStatement.getThrowExpression + if (throwExpression != null) { + throwExpression.accept(this) + } + // TODO at the moment we do not handle exception handling + // and thus simply ignore the influence of 'throw' on the + // cfg. + } + + override def visit(tryStatement: TryStatement): Unit = { + // TODO at the moment we do not handle exception handling + // and thus pretend the try does not exist. + Option(tryStatement.getContent).foreach(_.accept(this)) + Option(tryStatement.getFinallyContent).foreach(_.accept(this)) + } + + override def visit(unaryExpression: UnaryExpression): Unit = { + Option(unaryExpression.getChild(1)) match { + case Some(child) => + // Child 0 is the operator child 1 is the operand. + child.accept(this) + case None => + // We get here for `new` expression. 
+ } + + extendCfg(unaryExpression) + } + + override def visit(whileStatement: WhileStatement): Unit = { + breakStack.pushLayer() + continueStack.pushLayer() + + markerStack = None :: markerStack + whileStatement.getCondition.accept(this) + val conditionFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + + whileStatement.getStatement.accept(this) + + fringe = fringe.add(continueStack.getTopElements, AlwaysEdge) + + extendCfg(markerStack.head.get) + + fringe = conditionFringe + .setCfgEdgeType(FalseEdge) + .add(breakStack.getTopElements, AlwaysEdge) + + markerStack = markerStack.tail + breakStack.popLayer() + continueStack.popLayer() + } + + private def acceptChildren(node: AstNode): Unit = { + node.getChildIterator.forEachRemaining((child: AstNode) => child.accept(this)) + } +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/cfg/LayeredStack.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/cfg/LayeredStack.scala new file mode 100644 index 0000000..938c4a0 --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/cfg/LayeredStack.scala @@ -0,0 +1,31 @@ +package io.shiftleft.fuzzyc2cpg.cfg + +class LayeredStack[ElementType] { + private case class StackElement(elements: List[ElementType] = List()) { + def addNode(element: ElementType): StackElement = { + StackElement(element :: elements) + } + } + + private var stack = List[StackElement]() + + def pushLayer(): Unit = { + stack = StackElement() :: stack + } + + def popLayer(): Unit = { + stack = stack.tail + } + + def store(node: ElementType): Unit = { + stack = stack.head.addNode(node) :: stack.tail + } + + def getTopElements: List[ElementType] = { + stack.head.elements + } + + def numberOfLayers: Int = { + stack.size + } +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/output/CpgOutputModule.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/CpgOutputModule.scala new file mode 100644 index 0000000..a9373a2 --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/CpgOutputModule.scala @@ 
-0,0 +1,24 @@ +package io.shiftleft.fuzzyc2cpg.output + +import io.shiftleft.proto.cpg.Cpg.CpgStruct +import java.io.IOException + +/** + * The CpgOutputModule describes the format of the CPG graph, e.g, TinkerGraph. + */ +trait CpgOutputModule { + + /** + * Identifier for this output module which can be used to derive a name for + * e.g. a resulting output file. + */ + def setOutputIdentifier(identifier: String): Unit + + /** + * Persists the individual CPG. + * + * @param cpg a CPG to be persisted (in memory or disk) + */ + @throws[IOException] + def persistCpg(cpg: CpgStruct.Builder): Unit +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/output/CpgOutputModuleFactory.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/CpgOutputModuleFactory.scala new file mode 100644 index 0000000..784a70a --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/CpgOutputModuleFactory.scala @@ -0,0 +1,24 @@ +package io.shiftleft.fuzzyc2cpg.output + +import java.io.IOException + +/** + * Output module factory. + */ +trait CpgOutputModuleFactory { + + /** + * A CpgOutputModule associated with the given factory. + * + * @return a singleton output module + */ + @throws[IOException] + def create(): CpgOutputModule + + /** + * A finalization method that potentially combines all CPGs added to any of the + * output module created through this factory. 
+ */ + @throws[IOException] + def persist(): Unit +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OutputModule.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OutputModule.scala new file mode 100644 index 0000000..0dc45fe --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OutputModule.scala @@ -0,0 +1,14 @@ +package io.shiftleft.fuzzyc2cpg.output.overflowdb + +import java.util.concurrent.BlockingQueue + +import io.shiftleft.fuzzyc2cpg.output.CpgOutputModule +import io.shiftleft.proto.cpg.Cpg.CpgStruct + +class OutputModule(queue: BlockingQueue[CpgStruct.Builder]) extends CpgOutputModule { + + override def setOutputIdentifier(identifier: String): Unit = {} + + override def persistCpg(cpg: CpgStruct.Builder): Unit = queue.add(cpg) + +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OutputModuleFactory.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OutputModuleFactory.scala new file mode 100644 index 0000000..534c43d --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OutputModuleFactory.scala @@ -0,0 +1,28 @@ +package io.shiftleft.fuzzyc2cpg.output.overflowdb + +import java.util.concurrent.BlockingQueue + +import io.shiftleft.fuzzyc2cpg.output.{CpgOutputModule, CpgOutputModuleFactory} +import io.shiftleft.proto.cpg.Cpg +import io.shiftleft.proto.cpg.Cpg.CpgStruct +import org.slf4j.LoggerFactory + +class OutputModuleFactory(outputPath: String, queue: BlockingQueue[CpgStruct.Builder]) extends CpgOutputModuleFactory { + + private val logger = LoggerFactory.getLogger(getClass) + private val writer = new OverflowDbWriter(outputPath, queue) + val writerThread = new Thread(writer) + writerThread.start() + + override def create(): CpgOutputModule = new OutputModule(queue) + + override def persist(): Unit = { + try { + val endMarker = Cpg.CpgStruct.newBuilder().addNode(Cpg.CpgStruct.Node.newBuilder().setKey(-1)) + queue.put(endMarker) + } catch 
{ + case _: InterruptedException => logger.warn("Interrupted during persist operation") + } + writerThread.join() + } +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OverflowDbWriter.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OverflowDbWriter.scala new file mode 100644 index 0000000..b5b0331 --- /dev/null +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/output/overflowdb/OverflowDbWriter.scala @@ -0,0 +1,44 @@ +package io.shiftleft.fuzzyc2cpg.output.overflowdb + +import java.util.concurrent.BlockingQueue + +import better.files.File +import io.shiftleft.codepropertygraph.cpgloading.{CpgLoader, ProtoToCpg} +import io.shiftleft.proto.cpg.Cpg.CpgStruct +import org.slf4j.LoggerFactory +import overflowdb.OdbConfig + +class OverflowDbWriter(outputPath: String, queue: BlockingQueue[CpgStruct.Builder]) extends Runnable { + + private val logger = LoggerFactory.getLogger(getClass) + + override def run(): Unit = { + + val outFile = File(outputPath) + if (outputPath != "" && outFile.exists) { + logger.info("Output file exists, removing: " + outputPath) + outFile.delete() + } + val odbConfig = OdbConfig.withDefaults.withStorageLocation(outputPath) + val protoToCpg = new ProtoToCpg(odbConfig) + try { + var terminate = false; + while (!terminate) { + val subCpg = queue.take() + if (subCpg.getNodeCount == 1 && subCpg.getNode(0).getKey == -1) { + terminate = true + } else { + protoToCpg.addNodes(subCpg.getNodeList) + protoToCpg.addEdges(subCpg.getEdgeList) + } + } + + } catch { + case _: InterruptedException => logger.warn("Interrupted OverflowDbWriter.") + } finally { + val cpg = protoToCpg.build + CpgLoader.createIndexes(cpg) + cpg.graph.close() + } + } +} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/AstCreationPass.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/AstCreationPass.scala index 3f3e2e1..ed58b40 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/AstCreationPass.scala +++ 
b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/AstCreationPass.scala @@ -2,9 +2,10 @@ package io.shiftleft.fuzzyc2cpg.passes import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} +import io.shiftleft.fuzzyc2cpg.Utils.getGlobalNamespaceBlockFullName import io.shiftleft.fuzzyc2cpg.passes.astcreation.{AntlrCModuleParserDriver, AstVisitor} import io.shiftleft.fuzzyc2cpg.{Defines, Global} -import io.shiftleft.passes.{DiffGraph, IntervalKeyPool, ParallelCpgPass} +import io.shiftleft.passes.{DiffGraph, IntervalKeyPool, KeyPool, ParallelCpgPass} import org.slf4j.LoggerFactory /** @@ -27,12 +28,11 @@ class AstCreationPass(filenames: List[String], cpg: Cpg, keyPool: IntervalKeyPoo diffGraph.addNode(fileNode) val namespaceBlock = nodes.NewNamespaceBlock( name = Defines.globalNamespaceName, - fullName = CMetaDataPass.getGlobalNamespaceBlockFullName(Some(fileNode.name)) + fullName = getGlobalNamespaceBlockFullName(Some(fileNode.name)) ) diffGraph.addNode(fileNode) diffGraph.addNode(namespaceBlock) diffGraph.addEdge(namespaceBlock, fileNode, EdgeTypes.SOURCE_FILE) - diffGraph.addEdge(fileNode, namespaceBlock, EdgeTypes.AST) val driver = createDriver(fileNode, namespaceBlock) tryToParse(driver, filename, diffGraph) diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPass.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPass.scala index 2fe5414..eebf786 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPass.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPass.scala @@ -1,8 +1,9 @@ package io.shiftleft.fuzzyc2cpg.passes import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Languages, nodes} +import io.shiftleft.codepropertygraph.generated.{Languages, nodes} import io.shiftleft.fuzzyc2cpg.Defines +import io.shiftleft.fuzzyc2cpg.Utils.getGlobalNamespaceBlockFullName import io.shiftleft.passes.{CpgPass, DiffGraph, 
KeyPool} /** @@ -20,12 +21,9 @@ class CMetaDataPass(cpg: Cpg, keyPool: Option[KeyPool] = None) extends CpgPass(c def addAnyNamespaceBlock(diffGraph: DiffGraph.Builder): Unit = { val node = nodes.NewNamespaceBlock( name = Defines.globalNamespaceName, - fullName = CMetaDataPass.getGlobalNamespaceBlockFullName(None) + fullName = getGlobalNamespaceBlockFullName(None) ) - val fileWithNoName = nodes.NewFile(name = "") - diffGraph.addNode(fileWithNoName) diffGraph.addNode(node) - diffGraph.addEdge(node, fileWithNoName, EdgeTypes.SOURCE_FILE) } val diffGraph = DiffGraph.newBuilder @@ -34,16 +32,3 @@ class CMetaDataPass(cpg: Cpg, keyPool: Option[KeyPool] = None) extends CpgPass(c Iterator(diffGraph.build()) } } - -object CMetaDataPass { - - def getGlobalNamespaceBlockFullName(fileNameOption: Option[String]): String = { - fileNameOption match { - case Some(fileName) => - s"$fileName:${Defines.globalNamespaceName}" - case None => - Defines.globalNamespaceName - } - } - -} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPass.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPass.scala index 3a0aad6..248805f 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPass.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPass.scala @@ -1,36 +1,434 @@ package io.shiftleft.fuzzyc2cpg.passes import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.passes.{DiffGraph, IntervalKeyPool, ParallelCpgPass} -import io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.fuzzyc2cpg.passes.cfgcreation.CfgCreator +import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators, nodes} +import io.shiftleft.fuzzyc2cpg.adapter.{AlwaysEdge, CaseEdge, CfgEdgeType, FalseEdge, TrueEdge} +import io.shiftleft.fuzzyc2cpg.cfg.LayeredStack import io.shiftleft.semanticcpg.language._ +import org.slf4j.LoggerFactory -/** - * A pass that creates control flow 
graphs from abstract syntax trees. - * - * Control flow graphs can be calculated independently per method. - * Therefore, we inherit from `ParallelCpgPass`. As for other - * parallel passes, we provide a key pool that is split into equal - * parts, each of which is assigned to exactly one method prior - * to branching off into parallel computation. This ensures id - * stability over multiple runs. - * - * Note: the version of OverflowDB that we currently use as a - * storage backend does not assign ids to edges and this pass - * only creates edges at the moment. Therefore, we could do - * without key pools, however, this would not only deviate - * from the standard template for parallel CPG passes but it - * is also likely to bite us later, whenever we find that - * adding nodes in this pass or adding edge ids to the - * backend becomes necessary. - * */ class CfgCreationPass(cpg: Cpg, keyPool: IntervalKeyPool) extends ParallelCpgPass[nodes.Method](cpg, keyPools = Some(keyPool.split(cpg.method.size))) { override def partIterator: Iterator[nodes.Method] = cpg.method.iterator override def runOnPart(method: nodes.Method): Iterator[DiffGraph] = - new CfgCreator(method).run() + new CfgCreatorForMethod(method).run() + +} + +class CfgCreatorForMethod(entryNode: nodes.Method) { + + private implicit class FringeWrapper(fringe: List[FringeElement]) { + def setCfgEdgeType(cfgEdgeType: CfgEdgeType): List[FringeElement] = { + fringe.map { + case FringeElement(node, _) => + FringeElement(node, cfgEdgeType) + } + } + def add(node: nodes.CfgNode, cfgEdgeType: CfgEdgeType): List[FringeElement] = + FringeElement(node, cfgEdgeType) :: fringe + + def add(ns: List[nodes.CfgNode], cfgEdgeType: CfgEdgeType): List[FringeElement] = + ns.map(node => FringeElement(node, cfgEdgeType)) ++ fringe + + def add(otherFringe: List[FringeElement]): List[FringeElement] = + otherFringe ++ fringe + } + + private val logger = LoggerFactory.getLogger(getClass) + val diffGraph: DiffGraph.Builder = 
DiffGraph.newBuilder + + private var fringe = List[FringeElement]().add(entryNode, AlwaysEdge) + private var markerStack = List[Option[nodes.CfgNode]]() + private case class FringeElement(node: nodes.CfgNode, cfgEdgeType: CfgEdgeType) + private var labeledNodes = Map[String, nodes.CfgNode]() + private var pendingGotoLabels = List[String]() + private var pendingCaseLabels = List[String]() + private var returns = List[nodes.CfgNode]() + private val breakStack = new LayeredStack[nodes.CfgNode]() + private val continueStack = new LayeredStack[nodes.CfgNode]() + private val caseStack = new LayeredStack[(nodes.CfgNode, Boolean)]() + private var gotos = List[(nodes.CfgNode, String)]() + + def run(): Iterator[DiffGraph] = { + postOrderLeftToRightExpand(entryNode) + connectGotosAndLabels() + connectReturnsToExit() + Iterator(diffGraph.build) + } + + private def postOrderLeftToRightExpand(node: nodes.AstNode): Unit = { + node match { + case n: nodes.ControlStructure => + handleControlStructure(n) + case n: nodes.JumpTarget => + handleJumpTarget(n) + case call: nodes.Call if call.name == Operators.conditional => + handleConditionalExpression(call) + case call: nodes.Call if call.name == Operators.logicalAnd => + handleAndExpression(call) + case call: nodes.Call if call.name == Operators.logicalOr => + handleOrExpression(call) + case call: nodes.Call => + handleCall(call) + case identifier: nodes.Identifier => + handleIdentifier(identifier) + case literal: nodes.Literal => + handleLiteral(literal) + case actualRet: nodes.Return => + handleReturn(actualRet) + case formalRet: nodes.MethodReturn => + handleFormalReturn(formalRet) + case n: nodes.AstNode => + expandChildren(n) + } + } + + private def handleCall(call: nodes.Call): Unit = { + expandChildren(call) + extendCfg(call) + } + + private def handleIdentifier(identifier: nodes.Identifier): Unit = { + extendCfg(identifier) + } + + private def handleLiteral(literal: nodes.Literal): Unit = { + extendCfg(literal) + } + + private 
def handleReturn(actualRet: nodes.Return): Unit = { + expandChildren(actualRet) + extendCfg(actualRet) + fringe = Nil + returns = actualRet :: returns + } + + private def handleFormalReturn(formalRet: nodes.MethodReturn): Unit = { + extendCfg(formalRet) + } + + private def connectGotosAndLabels(): Unit = { + gotos.foreach { + case (goto, label) => + labeledNodes.get(label) match { + case Some(labeledNode) => + // TODO: CFG_EDGE_TYPE isn't defined for non-proto CPGs + // .addProperty(EdgeProperty.CFG_EDGE_TYPE, AlwaysEdge.toString) + diffGraph.addEdge( + goto, + labeledNode, + EdgeTypes.CFG + ) + case None => + logger.info("Unable to wire goto statement. Missing label {}.", label) + } + } + } + + private def connectReturnsToExit(): Unit = { + returns.foreach( + diffGraph.addEdge( + _, + entryNode.methodReturn, + EdgeTypes.CFG + ) + ) + } + + private def handleJumpTarget(n: nodes.JumpTarget): Unit = { + val labelName = n.name + if (labelName.startsWith("case") || labelName.startsWith("default")) { + pendingCaseLabels = labelName :: pendingCaseLabels + } else { + pendingGotoLabels = labelName :: pendingGotoLabels + } + } + + private def handleConditionalExpression(call: nodes.Call): Unit = { + val condition = call.argument(1) + val trueExpression = call.argument(2) + val falseExpression = call.argument(3) + + postOrderLeftToRightExpand(condition) + val fromCond = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + postOrderLeftToRightExpand(trueExpression) + val fromTrue = fringe + fringe = fromCond.setCfgEdgeType(FalseEdge) + postOrderLeftToRightExpand(falseExpression) + fringe = fringe.add(fromTrue) + extendCfg(call) + } + + private def handleAndExpression(call: Call): Unit = { + postOrderLeftToRightExpand(call.argument(1)) + val entry = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + postOrderLeftToRightExpand(call.argument(2)) + fringe = fringe.add(entry.setCfgEdgeType(FalseEdge)) + extendCfg(call) + } + + private def handleOrExpression(call: Call): Unit = { + 
val left = call.argument(1) + val right = call.argument(2) + postOrderLeftToRightExpand(left) + val entry = fringe + fringe = fringe.setCfgEdgeType(FalseEdge) + postOrderLeftToRightExpand(right) + fringe = fringe.add(entry.setCfgEdgeType(TrueEdge)) + extendCfg(call) + } + + private def handleBreakStatement(node: nodes.ControlStructure): Unit = { + extendCfg(node) + // Under normal conditions this is always true. + // But if the parser missed a loop or switch statement, breakStack + // might by empty. + if (breakStack.numberOfLayers > 0) { + fringe = Nil + breakStack.store(node) + } + } + + private def handleContinueStatement(node: nodes.ControlStructure): Unit = { + extendCfg(node) + // Under normal conditions this is always true. + // But if the parser missed a loop statement, continueStack + // might by empty. + if (continueStack.numberOfLayers > 0) { + fringe = Nil + continueStack.store(node) + } + } + + private def handleWhileStatement(node: nodes.ControlStructure): Unit = { + breakStack.pushLayer() + continueStack.pushLayer() + + markerStack = None :: markerStack + node.start.condition.headOption.foreach(postOrderLeftToRightExpand) + val conditionFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + + node.start.whenTrue.l.foreach(postOrderLeftToRightExpand) + fringe = fringe.add(continueStack.getTopElements, AlwaysEdge) + extendCfg(markerStack.head.get) + + fringe = conditionFringe + .setCfgEdgeType(FalseEdge) + .add(breakStack.getTopElements, AlwaysEdge) + + markerStack = markerStack.tail + breakStack.popLayer() + continueStack.popLayer() + } + + private def handleDoStatement(node: nodes.ControlStructure): Unit = { + breakStack.pushLayer() + continueStack.pushLayer() + + markerStack = None :: markerStack + node.astChildren.filter(_.order(1)).foreach(postOrderLeftToRightExpand) + fringe = fringe.add(continueStack.getTopElements, AlwaysEdge) + + node.start.condition.headOption match { + case Some(condition) => + postOrderLeftToRightExpand(condition) + 
val conditionFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + + extendCfg(markerStack.head.get) + + fringe = conditionFringe.setCfgEdgeType(FalseEdge) + case None => + // We only get here if the parser missed the condition. + // In this case doing nothing here means that we have + // no CFG edge to the loop start because we default + // to an always false condition. + } + fringe = fringe.add(breakStack.getTopElements, AlwaysEdge) + + markerStack = markerStack.tail + breakStack.popLayer() + continueStack.popLayer() + } + + private def handleForStatement(node: nodes.ControlStructure): Unit = { + breakStack.pushLayer() + continueStack.pushLayer() + + val children = node.astChildren.l + val initExprOption = children.find(_.order == 1) + val conditionOption = children.find(_.order == 2) + val loopExprOption = children.find(_.order == 3) + val statementOption = children.find(_.order == 4) + + initExprOption.foreach(postOrderLeftToRightExpand) + + markerStack = None :: markerStack + val conditionFringe = + conditionOption match { + case Some(condition) => + postOrderLeftToRightExpand(condition) + val storedFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + storedFringe + case None => Nil + } + + statementOption.foreach(postOrderLeftToRightExpand) + + fringe = fringe.add(continueStack.getTopElements, AlwaysEdge) + + loopExprOption.foreach(postOrderLeftToRightExpand) + + markerStack.head.foreach(extendCfg) + + fringe = conditionFringe + .setCfgEdgeType(FalseEdge) + .add(breakStack.getTopElements, AlwaysEdge) + + markerStack = markerStack.tail + breakStack.popLayer() + continueStack.popLayer() + } + + private def handleGotoStatement(node: nodes.ControlStructure): Unit = { + extendCfg(node) + fringe = Nil + // TODO: the target name should be in the AST + node.code.split(" ").lastOption.map(x => x.slice(0, x.length - 1)).foreach { target => + gotos = (node, target) :: gotos + } + } + + private def handleIfStatement(node: nodes.ControlStructure): Unit = { 
+ node.start.condition.foreach(postOrderLeftToRightExpand) + val conditionFringe = fringe + fringe = fringe.setCfgEdgeType(TrueEdge) + node.start.whenTrue.foreach(postOrderLeftToRightExpand) + node.start.whenFalse + .map { elseStatement => + val ifBlockFringe = fringe + fringe = conditionFringe.setCfgEdgeType(FalseEdge) + postOrderLeftToRightExpand(elseStatement) + fringe = fringe.add(ifBlockFringe) + } + .headOption + .getOrElse { + fringe = fringe.add(conditionFringe.setCfgEdgeType(FalseEdge)) + } + } + + private def handleSwitchStatement(node: nodes.ControlStructure): Unit = { + node.start.condition.foreach(postOrderLeftToRightExpand) + val conditionFringe = fringe.setCfgEdgeType(CaseEdge) + fringe = Nil + + // We can only push the break and case stacks after we processed the condition + // in order to allow for nested switches with no nodes CFG nodes in between + // an outer switch case label and the inner switch condition. + // This is ok because in C/C++ it is not allowed to have another switch + // statement in the condition of a switch statement. 
+ breakStack.pushLayer() + caseStack.pushLayer() + + node.start.whenTrue.foreach(postOrderLeftToRightExpand) + val switchFringe = fringe + + caseStack.getTopElements.foreach { + case (caseNode, _) => + fringe = conditionFringe + extendCfg(caseNode) + } + + val hasDefaultCase = caseStack.getTopElements.exists { + case (_, isDefault) => + isDefault + } + + fringe = switchFringe.add(breakStack.getTopElements, AlwaysEdge) + + if (!hasDefaultCase) { + fringe = fringe.add(conditionFringe) + } + + breakStack.popLayer() + caseStack.popLayer() + } + + private def handleControlStructure(node: nodes.ControlStructure): Unit = { + node.parserTypeName match { + case "BreakStatement" => + handleBreakStatement(node) + case "ContinueStatement" => + handleContinueStatement(node) + case "WhileStatement" => + handleWhileStatement(node) + case "DoStatement" => + handleDoStatement(node) + case "ForStatement" => + handleForStatement(node) + case "GotoStatement" => + handleGotoStatement(node) + case "IfStatement" => + handleIfStatement(node) + case "ElseStatement" => + expandChildren(node) + case "SwitchStatement" => + handleSwitchStatement(node) + case _ => + } + } + + private def expandChildren(node: nodes.AstNode): Unit = { + val children = node.astChildren.l + children.foreach(postOrderLeftToRightExpand) + } + + private def extendCfg(dstNode: nodes.CfgNode): Unit = { + fringe.foreach { + case FringeElement(srcNode, _) => + // TODO add edge CFG edge type in CPG spec + // val props = List(("CFG_EDGE_TYPE", cfgEdgeType.toString)) + diffGraph.addEdge( + srcNode, + dstNode, + EdgeTypes.CFG + ) + } + fringe = Nil.add(dstNode, AlwaysEdge) + + if (markerStack.nonEmpty) { + // Up until the first none None stack element we replace the Nones with Some(dstNode) + val leadingNoneLength = markerStack.segmentLength(_.isEmpty, 0) + markerStack = List.fill(leadingNoneLength)(Some(dstNode)) ++ markerStack + .drop(leadingNoneLength) + } + + if (pendingGotoLabels.nonEmpty) { + pendingGotoLabels.foreach { 
label => + labeledNodes = labeledNodes + (label -> dstNode) + } + pendingGotoLabels = List() + } + + // TODO at the moment we discard the case labels + if (pendingCaseLabels.nonEmpty) { + // Under normal conditions this is always true. + // But if the parser missed a switch statement, caseStack + // might by empty. + if (caseStack.numberOfLayers > 0) { + val containsDefaultLabel = pendingCaseLabels.contains("default") + caseStack.store((dstNode, containsDefaultLabel)) + } + pendingCaseLabels = List() + } + } } diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/StubRemovalPass.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/StubRemovalPass.scala index fe51d94..6a2b427 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/StubRemovalPass.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/StubRemovalPass.scala @@ -4,24 +4,22 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.passes.{DiffGraph, ParallelCpgPass} import io.shiftleft.semanticcpg.language._ import io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.codepropertygraph.generated.nodes.Method /** * A pass that ensures that for any method m for which a body exists, * there are no more method stubs for corresponding declarations. 
* */ class StubRemovalPass(cpg: Cpg) extends ParallelCpgPass[nodes.Method](cpg) { + override def partIterator: Iterator[nodes.Method] = + cpg.method.isNotStub.iterator - private val sigToMethodWithDef = cpg.method.isNotStub.map(m => (m.signature -> true)).toMap - - override def partIterator: Iterator[Method] = - cpg.method.isStub.toList - .filter(m => sigToMethodWithDef.contains(m.signature)) - .iterator - - override def runOnPart(stub: Method): Iterator[DiffGraph] = { + override def runOnPart(method: nodes.Method): Iterator[DiffGraph] = { val diffGraph = DiffGraph.newBuilder - stub.ast.foreach(diffGraph.removeNode) + cpg.method.isStub.where(m => m.signature == method.signature).foreach { stubMethod => + stubMethod.ast.l.foreach { node => + diffGraph.removeNode(node.id2()) + } + } Iterator(diffGraph.build) } } diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/TypeNodePass.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/TypeNodePass.scala deleted file mode 100644 index 57f61f3..0000000 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/TypeNodePass.scala +++ /dev/null @@ -1,21 +0,0 @@ -package io.shiftleft.fuzzyc2cpg.passes - -import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.passes.{CpgPass, DiffGraph, KeyPool} -import io.shiftleft.codepropertygraph.generated.nodes - -class TypeNodePass(usedTypes: List[String], cpg: Cpg, keyPool: Option[KeyPool] = None) - extends CpgPass(cpg, "types", keyPool) { - override def run(): Iterator[DiffGraph] = { - val diffGraph = DiffGraph.newBuilder - usedTypes.sorted.foreach { typeName => - val node = nodes.NewType( - name = typeName, - fullName = typeName, - typeDeclFullName = typeName - ) - diffGraph.addNode(node) - } - Iterator(diffGraph.build) - } -} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AntlrParserDriver.java b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AntlrParserDriver.java index 79971ea..baed4e9 100644 --- 
a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AntlrParserDriver.java +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AntlrParserDriver.java @@ -7,7 +7,6 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewFile; import io.shiftleft.fuzzyc2cpg.ast.AstNode; import io.shiftleft.fuzzyc2cpg.ast.AstNodeBuilder; -import io.shiftleft.fuzzyc2cpg.ast.logical.statements.CompoundStatement; import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriverObserver; import io.shiftleft.fuzzyc2cpg.parser.CommonParserContext; import io.shiftleft.fuzzyc2cpg.parser.TokenSubStream; @@ -33,6 +32,7 @@ import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.ParseTreeListener; import org.antlr.v4.runtime.tree.ParseTreeWalker; +import io.shiftleft.codepropertygraph.generated.nodes.File; import scala.Some; import scala.collection.immutable.List$; @@ -54,7 +54,6 @@ abstract public class AntlrParserDriver { public DiffGraph.Builder cpg; private final List observers = new ArrayList<>(); private NewFile fileNode; - private Parser antlrParser; public AntlrParserDriver() { super(); @@ -102,32 +101,6 @@ private void handleHiddenTokens(String filename) { } } - public void parseAndWalkTokenStream(TokenSubStream tokens) - throws ParserException { - filename = ""; - stream = tokens; - ParseTree tree = parseTokenStream(tokens); - walkTree(tree); - } - - - public ParseTree parseAndWalkString(String input) throws ParserException { - ParseTree tree = parseString(input); - walkTree(tree); - return tree; - } - - public CompoundStatement getResult() { - return (CompoundStatement) builderStack.peek().getItem(); - } - - public ParseTree parseString(String input) throws ParserException { - CharStream inputStream = CharStreams.fromString(input); - Lexer lex = createLexer(inputStream); - TokenSubStream tokens = new TokenSubStream(lex); - ParseTree tree = parseTokenStream(tokens); - return tree; - } public ParseTree parseTokenStream(TokenSubStream tokens) throws 
ParserException { @@ -138,14 +111,6 @@ public ParseTree parseTokenStream(TokenSubStream tokens) return returnTree; } - public void setAntlrParser(Parser parser) { - antlrParser = parser; - } - - public Parser getAntlrParser() { - return antlrParser; - } - protected TokenSubStream createTokenStreamFromFile(String filename) throws ParserException { diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AstCreator.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AstCreator.scala index 772567e..6f1e914 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AstCreator.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/AstCreator.scala @@ -54,6 +54,7 @@ import io.shiftleft.fuzzyc2cpg.ast.statements.jump.{ ThrowStatement } import io.shiftleft.fuzzyc2cpg.ast.statements.{ExpressionStatement, IdentifierDeclStatement} +import io.shiftleft.fuzzyc2cpg.scope.Scope import io.shiftleft.passes.DiffGraph import io.shiftleft.proto.cpg.Cpg.{DispatchTypes, EvaluationStrategies} @@ -103,7 +104,7 @@ private[astcreation] class AstCreator(diffGraph: DiffGraph.Builder, "int" } - val signature = returnType + " " + astFunction.getFunctionSignature(false) + val signature = astFunction.getFunctionSignature(false) val location = astFunction.getLocation val method = nodes.NewMethod( @@ -164,15 +165,14 @@ private[astcreation] class AstCreator(diffGraph: DiffGraph.Builder, } else { "int" } - val location = astParameter.getLocation val parameter = nodes.NewMethodParameterIn( code = astParameter.getEscapedCodeStr, name = astParameter.getName, order = astParameter.getChildNumber + 1, evaluationStrategy = EvaluationStrategies.BY_VALUE.name(), typeFullName = registerType(parameterType), - lineNumber = location.startLine, - columnNumber = location.startPos + lineNumber = astParameter.getLocation.startLine, + columnNumber = astParameter.getLocation.startPos ) diffGraph.addNode(parameter) scope.addToScope(astParameter.getName, (parameter, 
parameterType)) @@ -873,7 +873,6 @@ private[astcreation] class AstCreator(diffGraph: DiffGraph.Builder, } private def newCallNode(astNode: AstNode, methodName: String): nodes.NewCall = { - val location = astNode.getLocation nodes.NewCall( name = methodName, dispatchType = DispatchTypes.STATIC_DISPATCH.name(), @@ -883,32 +882,30 @@ private[astcreation] class AstCreator(diffGraph: DiffGraph.Builder, code = astNode.getEscapedCodeStr, order = context.childNum, argumentIndex = context.childNum, - lineNumber = location.startLine, - columnNumber = location.startPos + lineNumber = astNode.getLocation.startLine, + columnNumber = astNode.getLocation.startPos ) } private def newUnknownNode(astNode: AstNode): nodes.NewUnknown = { - val location = astNode.getLocation nodes.NewUnknown( parserTypeName = astNode.getClass.getSimpleName, code = astNode.getEscapedCodeStr, order = context.childNum, argumentIndex = context.childNum, - lineNumber = location.startLine, - columnNumber = location.startPos + lineNumber = astNode.getLocation.startLine, + columnNumber = astNode.getLocation.startPos ) } private def newControlStructureNode(astNode: AstNode): nodes.NewControlStructure = { - val location = astNode.getLocation nodes.NewControlStructure( parserTypeName = astNode.getClass.getSimpleName, code = astNode.getEscapedCodeStr, order = context.childNum, argumentIndex = context.childNum, - lineNumber = location.startLine, - columnNumber = location.startPos + lineNumber = astNode.getLocation.startLine, + columnNumber = astNode.getLocation.startPos ) } diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/CModuleParserTreeListener.java b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/CModuleParserTreeListener.java index 21f0378..aadab70 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/CModuleParserTreeListener.java +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/CModuleParserTreeListener.java @@ -11,6 +11,7 @@ import 
io.shiftleft.fuzzyc2cpg.ast.statements.IdentifierDeclStatement; import io.shiftleft.fuzzyc2cpg.parser.CompoundItemAssembler; import io.shiftleft.fuzzyc2cpg.parser.ModuleFunctionParserInterface; +import io.shiftleft.fuzzyc2cpg.parser.modules.AntlrCModuleParserDriver; import io.shiftleft.fuzzyc2cpg.parser.modules.builder.FunctionDefBuilder; import io.shiftleft.fuzzyc2cpg.parser.shared.builders.ClassDefBuilder; import io.shiftleft.fuzzyc2cpg.parser.shared.builders.IdentifierDeclBuilder; diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/cfgcreation/Cfg.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/cfgcreation/Cfg.scala deleted file mode 100644 index 30e1caf..0000000 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/cfgcreation/Cfg.scala +++ /dev/null @@ -1,178 +0,0 @@ -package io.shiftleft.fuzzyc2cpg.passes.cfgcreation - -import io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.fuzzyc2cpg.passes.cfgcreation.Cfg.CfgEdgeType -import org.slf4j.LoggerFactory - -/** - * A control flow graph that is under construction, consisting of: - * - * @param entryNode the control flow graph's first node, that is, - * the node to which a CFG that appends this CFG - * should attach itself to. - * @param edges control flow edges between nodes of the - * code property graph. - * @param fringe nodes of the CFG for which an outgoing edge type - * is already known but the destination node is not. - * These nodes are connected when another CFG is - * appended to this CFG. - * - * In addition to these three core building blocks, we store labels - * and jump statements that have not been resolved and may be - * resolvable as parent sub trees or sibblings are translated. 
- * - * @param labeledNodes labels contained in the abstract syntax tree - * from which this CPG was generated - * @param caseLabels labels beginning with "case" - * @param breaks unresolved breaks collected along the way - * @param continues unresolved continues collected along the way - * @param gotos unresolved gotos collected along the way - * - * */ -case class Cfg(entryNode: Option[nodes.CfgNode] = None, - edges: List[CfgEdge] = List(), - fringe: List[(nodes.CfgNode, CfgEdgeType)] = List(), - labeledNodes: Map[String, nodes.CfgNode] = Map(), - breaks: List[nodes.CfgNode] = List(), - continues: List[nodes.CfgNode] = List(), - caseLabels: List[nodes.CfgNode] = List(), - gotos: List[(nodes.CfgNode, String)] = List()) { - - import Cfg._ - - private val logger = LoggerFactory.getLogger(getClass) - - /** - * Create a new CFG in which `other` is appended - * to this CFG. All nodes of the fringe are connected - * to `other`'s entry node and the new fringe is - * `other`'s fringe. The diffgraphs, jumps, and labels - * are the sum of those present in `this` and `other`. - * - * */ - def ++(other: Cfg): Cfg = { - if (other == Cfg.empty) { - this - } else if (this == Cfg.empty) { - other - } else { - this.copy( - fringe = other.fringe, - edges = this.edges ++ other.edges ++ - edgesFromFringeTo(this, other.entryNode), - gotos = this.gotos ++ other.gotos, - labeledNodes = this.labeledNodes ++ other.labeledNodes, - breaks = this.breaks ++ other.breaks, - continues = this.continues ++ other.continues, - caseLabels = this.caseLabels ++ other.caseLabels - ) - } - } - - def withFringeEdgeType(cfgEdgeType: CfgEdgeType): Cfg = { - this.copy(fringe = fringe.map { case (x, _) => (x, cfgEdgeType) }) - } - - /** - * Upon completing traversal of the abstract syntax tree, - * this method creates CFG edges between gotos and - * respective labels. 
- * */ - def withResolvedGotos(): Cfg = { - val edges = gotos.flatMap { - case (goto, label) => - labeledNodes.get(label) match { - case Some(labeledNode) => - // TODO set edge type of Always once the backend - // supports it - Some(CfgEdge(goto, labeledNode, AlwaysEdge)) - case None => - logger.info("Unable to wire goto statement. Missing label {}.", label) - None - } - } - this.copy(edges = this.edges ++ edges) - } - -} - -case class CfgEdge(src: nodes.CfgNode, dst: nodes.CfgNode, edgeType: CfgEdgeType) - -object Cfg { - - /** - * The safe "null" Cfg. - * */ - val empty: Cfg = new Cfg() - - trait CfgEdgeType - object TrueEdge extends CfgEdgeType { - override def toString: String = "TrueEdge" - } - object FalseEdge extends CfgEdgeType { - override def toString: String = "FalseEdge" - } - object AlwaysEdge extends CfgEdgeType { - override def toString: String = "AlwaysEdge" - } - object CaseEdge extends CfgEdgeType { - override def toString: String = "CaseEdge" - } - - /** - * Create edges from all nodes of cfg's fringe to `node`. - * */ - def edgesFromFringeTo(cfg: Cfg, node: Option[nodes.CfgNode]): List[CfgEdge] = { - edgesFromFringeTo(cfg.fringe, node) - } - - /** - * Create edges from all nodes of cfg's fringe to `node`, ignoring fringe edge types - * and using `cfgEdgeType` instead. 
- * */ - def edgesFromFringeTo(cfg: Cfg, node: Option[nodes.CfgNode], cfgEdgeType: CfgEdgeType): List[CfgEdge] = { - edges(cfg.fringe.map(_._1), node, cfgEdgeType) - } - - /** - * Create edges from a list (node, cfgEdgeType) pairs to `node` - * */ - def edgesFromFringeTo(fringeElems: List[(nodes.CfgNode, CfgEdgeType)], node: Option[nodes.CfgNode]): List[CfgEdge] = { - fringeElems.flatMap { - case (sourceNode, cfgEdgeType) => - node.map { dstNode => - CfgEdge(sourceNode, dstNode, cfgEdgeType) - } - } - } - - /** - * Create edges of given type from a list of source nodes to a destination node - * */ - def edges(sources: List[nodes.CfgNode], - dstNode: Option[nodes.CfgNode], - cfgEdgeType: CfgEdgeType = AlwaysEdge): List[CfgEdge] = { - edgesToMultiple(sources, dstNode.toList, cfgEdgeType) - } - - def singleEdge(source: nodes.CfgNode, - destination: nodes.CfgNode, - cfgEdgeType: CfgEdgeType = AlwaysEdge): List[CfgEdge] = { - edgesToMultiple(List(source), List(destination), cfgEdgeType) - } - - /** - * Create edges of given type from all nodes in `sources` to `node`. 
- * */ - def edgesToMultiple(sources: List[nodes.CfgNode], - destinations: List[nodes.CfgNode], - cfgEdgeType: CfgEdgeType = AlwaysEdge): List[CfgEdge] = { - - sources.flatMap { l => - destinations.map { n => - CfgEdge(l, n, cfgEdgeType) - } - } - } - -} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/cfgcreation/CfgCreator.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/cfgcreation/CfgCreator.scala deleted file mode 100644 index a65619d..0000000 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/cfgcreation/CfgCreator.scala +++ /dev/null @@ -1,404 +0,0 @@ -package io.shiftleft.fuzzyc2cpg.passes.cfgcreation - -import io.shiftleft.codepropertygraph.generated.{EdgeTypes, Operators, nodes} -import io.shiftleft.codepropertygraph.generated.nodes.CfgNode -import io.shiftleft.fuzzyc2cpg.passes.cfgcreation.Cfg.CfgEdgeType -import io.shiftleft.passes.DiffGraph -import io.shiftleft.semanticcpg.language._ - -/** - * Translation of abstract syntax trees into control flow graphs - * - * The problem of translating an abstract syntax tree into a corresponding - * control flow graph can be formulated as a recursive problem in which - * sub trees of the syntax tree are translated and their corresponding - * control flow graphs are connected according to the control flow - * semantics of the root node. 
- * For example, consider the abstract syntax tree for an if-statement: - * - * ( if ) - * / \ - * (x < 10) (x += 1) - * / \ / \ - * x 10 x 1 - * - * This tree can be translated into a control flow graph, by translating - * the sub tree rooted in `x < 10` and that of `x += 1` and connecting - * their control flow graphs according to the semantics of `if`: - * - * [x < 10]---- - * |t f| - * [x +=1 ] | - * | - * The semantics of if dictate that the first sub tree to the left - * is a condition, which is connected to the CFG of the second sub - * tree - the body of the if statement - via a control flow edge with - * the `true` label (indicated in the illustration by `t`), and to the CFG - * of any follow-up code via a `false` edge (indicated by `f`). - * - * A problem that becomes immediately apparent in the illustration is that - * the result of translating a sub tree may leave us with edges for which - * a source node is known but the destination node depends on parents or - * siblings that were not considered in the translation. For example, we know - * that an outgoing edge from [x<10] must exist, but we do not yet know where - * it should lead. We refer to the set of nodes of the control flow graph with - * outgoing edges for which the destination node is yet to be determined as - * the "fringe" of the control flow graph. - */ -class CfgCreator(entryNode: nodes.Method) { - - import io.shiftleft.fuzzyc2cpg.passes.cfgcreation.Cfg._ - import CfgCreator._ - - /** - * Control flow graph definitions often feature a designated entry - * and exit node for each method. While these nodes are no-ops - * from a computational point of view, they are useful to - * guarantee that a method has exactly one entry and one exit. - * - * For the CPG-based control flow graph, we do not need to - * introduce fake entry and exit node. Instead, we can use the - * METHOD and METHOD_RETURN nodes as entry and exit nodes - * respectively. 
Note that METHOD_RETURN nodes are the nodes - * representing formal return parameters, of which there exists - * exactly one per method. - * */ - private val exitNode: nodes.MethodReturn = entryNode.methodReturn - - /** - * We return the CFG as a sequence of Diff Graphs that is - * calculated by first obtaining the CFG for the method - * and then resolving gotos. - * */ - def run(): Iterator[DiffGraph] = toDiffGraphs( - cfgForMethod(entryNode).withResolvedGotos().edges - ) - - private def toDiffGraphs(edges: List[CfgEdge]): Iterator[DiffGraph] = { - val diffGraph = DiffGraph.newBuilder - edges.foreach { edge => - // TODO we are ignoring edge.edgeType because the - // CFG spec doesn't define an edge type at the moment - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CFG) - } - Iterator(diffGraph).map(_.build) - } - - /** - * Conversion of a method to a CFG, showing the decomposition - * of the control flow graph generation problem into that of - * translating sub trees according to the node type. In the - * particular case of a method, the CFG is obtained by - * creating a CFG containing the single method node and - * a fringe containing the node and an outgoing AlwaysEdge, - * to the CFG obtained by translating child CFGs one by - * one and appending them. - * */ - private def cfgForMethod(node: nodes.Method): Cfg = - cfgForSingleNode(node) ++ cfgForChildren(node) - - /** - * For any single AST node, we can construct a CFG - * containing that single node by setting it as - * the entry node and placing it in the fringe. - * */ - private def cfgForSingleNode(node: nodes.CfgNode): Cfg = - Cfg(entryNode = Some(node), fringe = List((node, AlwaysEdge))) - - /** - * The CFG for all children is obtained by translating - * child ASTs one by one from left to right and appending - * them. 
- * */ - private def cfgForChildren(node: nodes.AstNode): Cfg = - node.astChildren.l.map(cfgFor).reduceOption((x, y) => x ++ y).getOrElse(Cfg.empty) - - /** - * This method dispatches AST nodes by type and calls - * corresponding conversion methods. - * */ - private def cfgFor(node: nodes.AstNode): Cfg = - node match { - case n: nodes.ControlStructure => - cfgForControlStructure(n) - case n: nodes.JumpTarget => - cfgForJumpTarget(n) - case actualRet: nodes.Return => cfgForReturn(actualRet) - case call: nodes.Call if call.name == Operators.logicalAnd => - cfgForAndExpression(call) - case call: nodes.Call if call.name == Operators.logicalOr => - cfgForOrExpression(call) - case call: nodes.Call if call.name == Operators.conditional => - cfgForConditionalExpression(call) - case (_: nodes.Call | _: nodes.Identifier | _: nodes.Literal | _: nodes.MethodReturn) => - cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[nodes.CfgNode]) - case _ => - cfgForChildren(node) - } - - /** - * A second layer of dispatching for control structures. This could - * as well be part of `cfgFor` and has only been placed into a - * separate function to increase readability. - * */ - private def cfgForControlStructure(node: nodes.ControlStructure): Cfg = - node.parserTypeName match { - case "BreakStatement" => - cfgForBreakStatement(node) - case "ContinueStatement" => - cfgForContinueStatement(node) - case "WhileStatement" => - cfgForWhileStatement(node) - case "DoStatement" => - cfgForDoStatement(node) - case "ForStatement" => - cfgForForStatement(node) - case "GotoStatement" => - cfgForGotoStatement(node) - case "IfStatement" => - cfgForIfStatement(node) - case "ElseStatement" => - cfgForChildren(node) - case "SwitchStatement" => - cfgForSwitchStatement(node) - case _ => - Cfg.empty - } - - /** - * The CFG for a break/continue statements contains only - * the break/continue statement as a single entry node. 
- * The fringe is empty, that is, appending - * another CFG to the break statement will - * not result in the creation of an edge from - * the break statement to the entry point - * of the other CFG. - * */ - private def cfgForBreakStatement(node: nodes.ControlStructure): Cfg = - Cfg(entryNode = Some(node), breaks = List(node)) - - private def cfgForContinueStatement(node: nodes.ControlStructure): Cfg = - Cfg(entryNode = Some(node), continues = List(node)) - - /** - * Jump targets ("labels") are included in the CFG. As these - * should be connected to the next appended CFG, we specify - * that the label node is both the entry node and the only - * node in the fringe. This is achieved by calling `cfgForSingleNode` - * on the label node. Just like for breaks and continues, we record - * labels. We store case/default labels separately from other labels, - * but that is not a relevant implementation detail. - * */ - private def cfgForJumpTarget(n: nodes.JumpTarget): Cfg = { - val labelName = n.name - val cfg = cfgForSingleNode(n) - if (labelName.startsWith("case") || labelName.startsWith("default")) { - cfg.copy(caseLabels = List(n)) - } else { - cfg.copy(labeledNodes = Map(labelName -> n)) - } - } - - /** - * A CFG for a goto statement is one containing the goto - * node as an entry node and an empty fringe. Moreover, we - * store the goto for dispatching with `withResolvedGotos` - * once the CFG for the entire method has been calculated. 
- * */ - private def cfgForGotoStatement(node: nodes.ControlStructure): Cfg = { - // TODO: the goto node should contain a field for the target so that - // we can avoid the brittle split/slice operation here - val target = node.code.split(" ").lastOption.map(x => x.slice(0, x.length - 1)) - target.map(t => Cfg(entryNode = Some(node), gotos = List((node, t)))).getOrElse(Cfg.empty) - } - - /** - * Return statements may contain expressions as return values, - * and therefore, the CFG for a return statement consists of - * the CFG for calculation of that expression, appended to - * a CFG containing only the return node, connected with - * a single edge to the method exit node. The fringe is - * empty. - * */ - private def cfgForReturn(actualRet: nodes.Return): Cfg = { - cfgForChildren(actualRet) ++ - Cfg(entryNode = Some(actualRet), edges = singleEdge(actualRet, exitNode), List()) - } - - /** - * The right hand side of a logical AND expression is only evaluated - * if the left hand side is true as the entire expression can only - * be true if both expressions are true. This is encoded in the - * corresponding control flow graph by creating control flow graphs - * for the left and right hand expressions and appending the two, - * where the fringe edge type of the left CFG is `TrueEdge`. - * */ - def cfgForAndExpression(call: nodes.Call): Cfg = { - val leftCfg = cfgFor(call.argument(1)) - val rightCfg = cfgFor(call.argument(2)) - val diffGraphs = edgesFromFringeTo(leftCfg, rightCfg.entryNode, TrueEdge) ++ leftCfg.edges ++ rightCfg.edges - Cfg(entryNode = leftCfg.entryNode, edges = diffGraphs, fringe = leftCfg.fringe ++ rightCfg.fringe) ++ cfgForSingleNode( - call) - } - - /** - * Same construction recipe as for the AND expression, just that the fringe edge type - * of the left CFG is `FalseEdge`. 
- * */ - def cfgForOrExpression(call: nodes.Call): Cfg = { - val leftCfg = cfgFor(call.argument(1)) - val rightCfg = cfgFor(call.argument(2)) - val diffGraphs = edgesFromFringeTo(leftCfg, rightCfg.entryNode, FalseEdge) ++ leftCfg.edges ++ rightCfg.edges - Cfg(entryNode = leftCfg.entryNode, edges = diffGraphs, fringe = leftCfg.fringe ++ rightCfg.fringe) ++ cfgForSingleNode( - call) - } - - /** - * A conditional expression is of the form `condition ? trueExpr ; falseExpr` - * We create the corresponding CFGs by creating CFGs for the three expressions - * and adding edges between them. The new entry node is the condition entry - * node. - * */ - private def cfgForConditionalExpression(call: nodes.Call): Cfg = { - val conditionCfg = cfgFor(call.argument(1)) - val trueCfg = cfgFor(call.argument(2)) - val falseCfg = cfgFor(call.argument(3)) - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode, TrueEdge) ++ - edgesFromFringeTo(conditionCfg, falseCfg.entryNode, FalseEdge) - - Cfg( - entryNode = conditionCfg.entryNode, - edges = conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges ++ diffGraphs, - fringe = trueCfg.fringe ++ falseCfg.fringe - ) ++ cfgForSingleNode(call) - } - - /** - * A for statement is of the form `for(initExpr; condition; loopExpr) body` - * and all four components may be empty. The sequence - * (condition - body - loopExpr) form the inner part of the loop - * and we calculate the corresponding CFG `innerCfg` so that it is no longer - * relevant which of these three actually exist and we still have an entry - * node for the loop and a fringe. 
- * */ - private def cfgForForStatement(node: nodes.ControlStructure): Cfg = { - val children = node.astChildren.l - val initExprCfg = children.find(_.order == 1).map(cfgFor).getOrElse(Cfg.empty) - val conditionCfg = children.find(_.order == 2).map(cfgFor).getOrElse(Cfg.empty) - val loopExprCfg = children.find(_.order == 3).map(cfgFor).getOrElse(Cfg.empty) - val bodyCfg = children.find(_.order == 4).map(cfgFor).getOrElse(Cfg.empty) - - val innerCfg = conditionCfg ++ bodyCfg ++ loopExprCfg - val entryNode = (initExprCfg ++ innerCfg).entryNode - - val newEdges = edgesFromFringeTo(initExprCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(innerCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, bodyCfg.entryNode, TrueEdge) ++ { - if (loopExprCfg.entryNode.isDefined) { - edges(bodyCfg.continues, loopExprCfg.entryNode) - } else { - edges(bodyCfg.continues, innerCfg.entryNode) - } - } - - Cfg( - entryNode = entryNode, - edges = newEdges ++ initExprCfg.edges ++ innerCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ bodyCfg.breaks.map((_, AlwaysEdge)) - ) - } - - /** - * A Do-Statement is of the form `do body while(condition)` where body may be empty. - * We again first calculate the inner CFG as bodyCfg ++ conditionCfg and then connect - * edges according to the semantics of do-while. 
- * */ - private def cfgForDoStatement(node: nodes.ControlStructure): Cfg = { - val bodyCfg = node.astChildren.filter(_.order(1)).headOption.map(cfgFor).getOrElse(Cfg.empty) - val conditionCfg = node.start.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val innerCfg = bodyCfg ++ conditionCfg - - val diffGraphs = - edges(bodyCfg.continues, conditionCfg.entryNode) ++ - edgesFromFringeTo(bodyCfg, conditionCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, innerCfg.entryNode, TrueEdge) - - Cfg( - entryNode = if (bodyCfg != Cfg.empty) { bodyCfg.entryNode } else { conditionCfg.entryNode }, - edges = diffGraphs ++ bodyCfg.edges ++ conditionCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ bodyCfg.breaks.map((_, AlwaysEdge)) - ) - } - - /** - * CFG creation for while statements of the form `while(condition) body` - * where body is optional. - * */ - private def cfgForWhileStatement(node: nodes.ControlStructure): Cfg = { - val conditionCfg = node.start.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val trueCfg = node.start.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ - edgesFromFringeTo(trueCfg, conditionCfg.entryNode) ++ - edges(trueCfg.continues, conditionCfg.entryNode) - - Cfg( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ trueCfg.breaks.map((_, AlwaysEdge)) - ) - } - - /** - * CFG creation for switch statements of the form `switch{ case $x: ... }`. 
- * */ - private def cfgForSwitchStatement(node: nodes.ControlStructure): Cfg = { - val conditionCfg = node.start.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val bodyCfg = node.start.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - val diffGraphs = edgesToMultiple(conditionCfg.fringe.map(_._1), bodyCfg.caseLabels, CaseEdge) - - val hasDefaultCase = bodyCfg.caseLabels.exists(x => x.asInstanceOf[nodes.JumpTarget].name == "default") - - Cfg( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ bodyCfg.edges, - fringe = { if (!hasDefaultCase) { conditionCfg.fringe.withEdgeType(FalseEdge) } else { List() } } ++ bodyCfg.breaks - .map((_, AlwaysEdge)) ++ bodyCfg.fringe - ) - } - - /** - * CFG creation for if statements of the form `if(condition) body`, optionally - * followed by `else body2`. - * */ - private def cfgForIfStatement(node: nodes.ControlStructure): Cfg = { - val conditionCfg = node.start.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val trueCfg = node.start.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - val falseCfg = node.start.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) - - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, falseCfg.entryNode) - - Cfg( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, - fringe = trueCfg.fringe ++ { - if (falseCfg.entryNode.isDefined) { - falseCfg.fringe - } else { - conditionCfg.fringe.withEdgeType(FalseEdge) - } - } - ) - } - -} - -object CfgCreator { - - implicit class FringeWrapper(fringe: List[(nodes.CfgNode, CfgEdgeType)]) { - def withEdgeType(edgeType: CfgEdgeType): List[(CfgNode, CfgEdgeType)] = { - fringe.map { case (x, _) => (x, edgeType) } - } - } - -} diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/Scope.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/scope/Scope.scala similarity index 94% rename 
from src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/Scope.scala rename to src/main/scala/io/shiftleft/fuzzyc2cpg/scope/Scope.scala index 875ab05..8bb47c4 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/Scope.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/scope/Scope.scala @@ -1,4 +1,4 @@ -package io.shiftleft.fuzzyc2cpg.passes.astcreation +package io.shiftleft.fuzzyc2cpg.scope /** * Handles the scope stack for tracking identifier to variable relation. diff --git a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/ScopeElement.scala b/src/main/scala/io/shiftleft/fuzzyc2cpg/scope/ScopeElement.scala similarity index 87% rename from src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/ScopeElement.scala rename to src/main/scala/io/shiftleft/fuzzyc2cpg/scope/ScopeElement.scala index f2d8745..1dd7bcb 100644 --- a/src/main/scala/io/shiftleft/fuzzyc2cpg/passes/astcreation/ScopeElement.scala +++ b/src/main/scala/io/shiftleft/fuzzyc2cpg/scope/ScopeElement.scala @@ -1,4 +1,4 @@ -package io.shiftleft.fuzzyc2cpg.passes.astcreation +package io.shiftleft.fuzzyc2cpg.scope /** * A single element of a scope stack. 
diff --git a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/AssignmentTests.java b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/AssignmentTests.java index beae548..6650886 100644 --- a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/AssignmentTests.java +++ b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/AssignmentTests.java @@ -3,7 +3,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.Test; diff --git a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/ForLoopTests.java b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/ForLoopTests.java index 2e0c3eb..387b5d9 100644 --- a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/ForLoopTests.java +++ b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/ForLoopTests.java @@ -2,7 +2,7 @@ import static org.junit.Assert.assertTrue; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.Test; diff --git a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCallTests.java b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCallTests.java index 61c3966..c81ec92 100644 --- a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCallTests.java +++ b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCallTests.java @@ -2,7 +2,7 @@ import static org.junit.Assert.assertTrue; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; import org.antlr.v4.runtime.tree.ParseTree; 
import org.junit.Test; diff --git a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCommentTests.java b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCommentTests.java index aa8fe25..3d2b136 100644 --- a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCommentTests.java +++ b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionCommentTests.java @@ -2,11 +2,11 @@ import static org.junit.Assert.assertEquals; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.Test; import io.shiftleft.fuzzyc2cpg.FunctionParser; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; public class FunctionCommentTests extends FunctionParserTestBase { diff --git a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTest.java b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTest.java index 1b50545..2d95be2 100644 --- a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTest.java +++ b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTest.java @@ -3,7 +3,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.Test; diff --git a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTestBase.java b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTestBase.java index 9686e24..05960f8 100644 --- a/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTestBase.java +++ b/src/test/java/io/shiftleft/fuzzyc2cpg/antlrparsers/functionparser/FunctionParserTestBase.java @@ -1,7 +1,6 @@ package 
io.shiftleft.fuzzyc2cpg.antlrparsers.functionparser; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrParserDriver; import org.antlr.v4.runtime.CharStream; import org.antlr.v4.runtime.CharStreams; import org.antlr.v4.runtime.CommonTokenStream; @@ -9,6 +8,8 @@ import io.shiftleft.fuzzyc2cpg.FunctionLexer; import io.shiftleft.fuzzyc2cpg.FunctionParser; +import io.shiftleft.fuzzyc2cpg.parser.AntlrParserDriver; +import io.shiftleft.fuzzyc2cpg.parser.TokenSubStream; import io.shiftleft.fuzzyc2cpg.parser.functions.AntlrCFunctionParserDriver; public class FunctionParserTestBase { diff --git a/src/test/java/io/shiftleft/fuzzyc2cpg/parsetreetoast/ModuleBuildersTest.java b/src/test/java/io/shiftleft/fuzzyc2cpg/parsetreetoast/ModuleBuildersTest.java index 15781fd..1ae2349 100644 --- a/src/test/java/io/shiftleft/fuzzyc2cpg/parsetreetoast/ModuleBuildersTest.java +++ b/src/test/java/io/shiftleft/fuzzyc2cpg/parsetreetoast/ModuleBuildersTest.java @@ -1,6 +1,5 @@ package io.shiftleft.fuzzyc2cpg.parsetreetoast; -import io.shiftleft.fuzzyc2cpg.passes.astcreation.AntlrCModuleParserDriver; import java.util.List; import org.antlr.v4.runtime.CharStream; @@ -17,6 +16,7 @@ import io.shiftleft.fuzzyc2cpg.ast.langc.functiondef.ParameterType; import io.shiftleft.fuzzyc2cpg.ast.statements.IdentifierDeclStatement; import io.shiftleft.fuzzyc2cpg.parser.TokenSubStream; +import io.shiftleft.fuzzyc2cpg.parser.modules.AntlrCModuleParserDriver; import static org.junit.Assert.*; diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/CpgTestFixture.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/CpgTestFixture.scala index 57c03a5..d7c26e4 100644 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/CpgTestFixture.scala +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/CpgTestFixture.scala @@ -2,24 +2,17 @@ package io.shiftleft.fuzzyc2cpg import gremlin.scala.GraphAsScala import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.fuzzyc2cpg.passes.{AstCreationPass, CMetaDataPass, CfgCreationPass, 
StubRemovalPass} -import io.shiftleft.passes.IntervalKeyPool -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.fuzzyc2cpg.output.inmemory.OutputModuleFactory case class CpgTestFixture(projectName: String) { - val cpg: Cpg = Cpg.emptyCpg - val dirName = String.format("src/test/resources/testcode/%s", projectName) - val keyPoolFile1 = new IntervalKeyPool(1001, 2000) - val cfgKeyPool = new IntervalKeyPool(2001, 3000) - val filenames = SourceFiles.determine(Set(dirName), Set(".c")) - - new CMetaDataPass(cpg).createAndApply() - new AstCreationPass(filenames, cpg, keyPoolFile1).createAndApply() - if (cpg.method.size > 0) { - new CfgCreationPass(cpg, cfgKeyPool).createAndApply() + val cpg: Cpg = { + val dirName = String.format("src/test/resources/testcode/%s", projectName) + val inmemoryOutputFactory = new OutputModuleFactory() + val fuzzyc2Cpg = new FuzzyC2Cpg(inmemoryOutputFactory) + fuzzyc2Cpg.runAndOutput(Set(dirName), Set(".c", ".cc", ".cpp", ".h", ".hpp")) + inmemoryOutputFactory.getInternalGraph } - new StubRemovalPass(cpg).createAndApply() def V = cpg.graph.asScala.V diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodDeclTest.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodDeclTest.scala index 3e632cb..f0c52fa 100644 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodDeclTest.scala +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodDeclTest.scala @@ -16,7 +16,7 @@ class MethodDeclTest extends WordSpec with Matchers { result.size shouldBe 1 val signature = result.head.property[String](NodeKeys.SIGNATURE.name).value - signature shouldBe "int add (int,int)" + signature shouldBe "int(int,int)" } } } diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodHeaderTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodHeaderTests.scala index c95e5c6..8134f12 100644 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodHeaderTests.scala +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodHeaderTests.scala @@ -14,7 +14,7 @@ class 
MethodHeaderTests extends WordSpec with Matchers { methods.size shouldBe 1 methods.head.value2(NodeKeys.IS_EXTERNAL) shouldBe false methods.head.value2(NodeKeys.FULL_NAME) shouldBe "foo" - methods.head.value2(NodeKeys.SIGNATURE) shouldBe "int foo (int,int)" + methods.head.value2(NodeKeys.SIGNATURE) shouldBe "int(int,int)" methods.head.value2(NodeKeys.LINE_NUMBER) shouldBe 1 methods.head.value2(NodeKeys.COLUMN_NUMBER) shouldBe 0 methods.head.value2(NodeKeys.LINE_NUMBER_END) shouldBe 3 diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodInternalLinkageTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodInternalLinkageTests.scala index 175119d..3370001 100644 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodInternalLinkageTests.scala +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/MethodInternalLinkageTests.scala @@ -59,19 +59,20 @@ class MethodInternalLinkageTests extends WordSpec with Matchers with TraversalUt parameterX.checkForSingle(NodeTypes.METHOD_PARAMETER_IN, NodeKeys.NAME, "x") } - "be correct for all identifiers x, y in method3" in { + "be correct for all indentifiers x, y in method3" in { val method = getMethod("method3") - val outerIdentifierX = method.expandAst().expandAst().filterOrder(3).expandAst(NodeTypes.IDENTIFIER) - outerIdentifierX.checkForSingle(NodeKeys.NAME, "x") - val parameterX = outerIdentifierX.expandRef() + + val outerIndentifierX = method.expandAst().expandAst().filterOrder(2).expandAst(NodeTypes.IDENTIFIER) + outerIndentifierX.checkForSingle(NodeKeys.NAME, "x") + val parameterX = outerIndentifierX.expandRef() parameterX.checkForSingle(NodeTypes.METHOD_PARAMETER_IN, NodeKeys.NAME, "x") - val expectedParameterX = method.expandAst(NodeTypes.METHOD_PARAMETER_IN) - expectedParameterX.checkForSingle(NodeKeys.NAME, "x") - parameterX shouldBe expectedParameterX + val expectedParamterX = method.expandAst(NodeTypes.METHOD_PARAMETER_IN) + expectedParamterX.checkForSingle(NodeKeys.NAME, "x") + parameterX shouldBe expectedParamterX - val 
outerIdentifierY = method.expandAst().expandAst().filterOrder(4).expandAst(NodeTypes.IDENTIFIER) - outerIdentifierY.checkForSingle(NodeKeys.NAME, "y") - val outerLocalY = outerIdentifierY.expandRef() + val outerIndentifierY = method.expandAst().expandAst().filterOrder(3).expandAst(NodeTypes.IDENTIFIER) + outerIndentifierY.checkForSingle(NodeKeys.NAME, "y") + val outerLocalY = outerIndentifierY.expandRef() outerLocalY.checkForSingle(NodeTypes.LOCAL, NodeKeys.NAME, "y") val expectedOuterLocalY = method.expandAst().expandAst(NodeTypes.LOCAL) expectedOuterLocalY.checkForSingle(NodeKeys.NAME, "y") @@ -79,7 +80,7 @@ class MethodInternalLinkageTests extends WordSpec with Matchers with TraversalUt val nestedBlock = method.expandAst().expandAst(NodeTypes.BLOCK) - val nestedIdentifierX = nestedBlock.expandAst().filterOrder(3).expandAst(NodeTypes.IDENTIFIER) + val nestedIdentifierX = nestedBlock.expandAst().filterOrder(1).expandAst(NodeTypes.IDENTIFIER) nestedIdentifierX.checkForSingle(NodeKeys.NAME, "x") val nestedLocalX = nestedIdentifierX.expandRef() nestedLocalX.checkForSingle(NodeTypes.LOCAL, NodeKeys.NAME, "x") @@ -87,7 +88,7 @@ class MethodInternalLinkageTests extends WordSpec with Matchers with TraversalUt expectedNestedLocalX.checkForSingle(NodeKeys.NAME, "x") nestedLocalX shouldBe expectedNestedLocalX - val nestedIdentifierY = nestedBlock.expandAst().filterOrder(4).expandAst(NodeTypes.IDENTIFIER) + val nestedIdentifierY = nestedBlock.expandAst().filterOrder(2).expandAst(NodeTypes.IDENTIFIER) nestedIdentifierY.checkForSingle(NodeKeys.NAME, "y") val nestedLocalY = nestedIdentifierY.expandRef() nestedLocalY.checkForSingle(NodeTypes.LOCAL, NodeKeys.NAME, "y") diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/StableOutputTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/StableOutputTests.scala index f627288..51cd216 100644 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/StableOutputTests.scala +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/StableOutputTests.scala @@ -1,5 
+1,6 @@ package io.shiftleft.fuzzyc2cpg +import io.shiftleft.fuzzyc2cpg.output.inmemory.OutputModuleFactory import org.scalatest.{Matchers, WordSpec} import scala.jdk.CollectionConverters._ @@ -9,8 +10,10 @@ class StableOutputTests extends WordSpec with Matchers { def createNodeStrings(): String = { val projectName = "stableid" val dirName = String.format("src/test/resources/testcode/%s", projectName) - val fuzzyc2Cpg = new FuzzyC2Cpg() - val cpg = fuzzyc2Cpg.runAndOutput(Set(dirName), Set(".c", ".cc", ".cpp", ".h", ".hpp")) + val inmemoryOutputFactory = new OutputModuleFactory() + val fuzzyc2Cpg = new FuzzyC2Cpg(inmemoryOutputFactory) + fuzzyc2Cpg.runAndOutput(Set(dirName), Set(".c", ".cc", ".cpp", ".h", ".hpp")) + val cpg = inmemoryOutputFactory.getInternalGraph val nodes = cpg.graph.V().asScala.toList nodes.sortBy(_.id2()).map(x => x.label + ": " + x.propertyMap().asScala.toString).mkString("\n") } diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/astnew/AstToCpgTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/astnew/AstToCpgTests.scala new file mode 100644 index 0000000..624fa0d --- /dev/null +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/astnew/AstToCpgTests.scala @@ -0,0 +1,1097 @@ +package io.shiftleft.fuzzyc2cpg.astnew + +import io.shiftleft.OverflowDbTestInstance +import org.antlr.v4.runtime.{CharStreams, ParserRuleContext} +import org.scalatest.{Matchers, WordSpec} +import io.shiftleft.codepropertygraph.generated.{EdgeTypes, NodeKeys, NodeKeysOdb, NodeTypes, Operators} +import io.shiftleft.fuzzyc2cpg.{Global, ModuleLexer} +import io.shiftleft.fuzzyc2cpg.adapter.CpgAdapter +import io.shiftleft.fuzzyc2cpg.adapter.EdgeKind.EdgeKind +import io.shiftleft.fuzzyc2cpg.adapter.EdgeProperty.EdgeProperty +import io.shiftleft.fuzzyc2cpg.adapter.NodeKind.NodeKind +import io.shiftleft.fuzzyc2cpg.adapter.NodeProperty.NodeProperty +import io.shiftleft.fuzzyc2cpg.ast.{AstNode, AstNodeBuilder} +import io.shiftleft.fuzzyc2cpg.parser.modules.AntlrCModuleParserDriver 
+import io.shiftleft.fuzzyc2cpg.parser.{AntlrParserDriverObserver, TokenSubStream} +import overflowdb._ +import overflowdb.traversal._ + +class AstToCpgTests extends WordSpec with Matchers { + + private class GraphAdapter(graph: OdbGraph) extends CpgAdapter[Node, Node, OdbEdge, OdbEdge] { + override def createNodeBuilder(kind: NodeKind): Node = { + graph.addNode(kind.toString) + } + + override def createNode(vertex: Node, origAstNode: AstNode): Node = { + vertex + } + + override def createNode(vertex: Node): Node = { + vertex + } + + override def addNodeProperty(vertex: Node, property: NodeProperty, value: String): Unit = { + vertex.property(property.toString, value) + } + + override def addNodeProperty(vertex: Node, property: NodeProperty, value: Int): Unit = { + vertex.property(property.toString, value) + } + + override def addNodeProperty(vertex: Node, property: NodeProperty, value: Boolean): Unit = { + vertex.property(property.toString, value) + } + + override def addNodeProperty(vertex: Node, property: NodeProperty, value: List[String]): Unit = { + vertex.property(property.toString, value) + } + + override def createEdgeBuilder(dst: Node, src: Node, edgeKind: EdgeKind): OdbEdge = { + src.addEdge2(edgeKind.toString, dst) + } + + override def createEdge(edge: OdbEdge): OdbEdge = { + edge + } + + // Not used in test with this adapter. + override def addEdgeProperty(edgeBuilder: OdbEdge, property: EdgeProperty, value: String): Unit = ??? + override def mapNode(astNode: AstNode): Node = ??? 
+ } + + private implicit class VertexListWrapper(vertexList: List[Node]) { + def expandAst(filterLabels: String*): List[Node] = { + if (filterLabels.nonEmpty) { + vertexList.flatMap(_.start.out(EdgeTypes.AST).hasLabel(filterLabels.head, filterLabels.tail: _*).l) + } else { + vertexList.flatMap(_.start.out(EdgeTypes.AST).l) + } + } + + def expandCondition: List[Node] = + vertexList.flatMap(_.start.out(EdgeTypes.CONDITION).l) + + def expandArgument: List[Node] = + vertexList.flatMap(_.start.out(EdgeTypes.ARGUMENT).l) + + def filterOrder(order: Int): List[Node] = { + vertexList.filter(_.property(NodeKeysOdb.ORDER) == order) + } + + def checkForSingle[T](propertyName: PropertyKey[T], value: T): Unit = { + vertexList.size shouldBe 1 + vertexList.head.property(propertyName) shouldBe value + } + + def checkForSingle(): Unit = { + vertexList.size shouldBe 1 + } + + def check[A](count: Int, mapFunc: Node => A, expectations: A*): Unit = { + vertexList.size shouldBe count + vertexList.map(mapFunc).toSet shouldBe expectations.toSet + } + + } + + class Fixture(code: String) { + + private class DriverObserver extends AntlrParserDriverObserver { + override def begin(): Unit = {} + + override def end(): Unit = {} + + override def startOfUnit(ctx: ParserRuleContext, filename: String): Unit = {} + + override def endOfUnit(ctx: ParserRuleContext, filename: String): Unit = {} + + override def processItem[T <: AstNode](node: T, + builderStack: java.util.Stack[AstNodeBuilder[_ <: AstNode]]): Unit = { + nodes = node :: nodes + } + } + + private var nodes = List[AstNode]() + + private val driver = new AntlrCModuleParserDriver() + driver.addObserver(new DriverObserver()) + + private val inputStream = CharStreams.fromString(code) + private val lex = new ModuleLexer(inputStream) + private val tokens = new TokenSubStream(lex) + + driver.parseAndWalkTokenStream(tokens) + + val graph: OdbGraph = OverflowDbTestInstance.create + private val astParentNode = graph.addNode("NAMESPACE_BLOCK") + 
protected val astParent = List(astParentNode) + private val cpgAdapter = new GraphAdapter(graph) + + val global = Global() + nodes.foreach { node => + val astToProtoConverter = new AstToCpgConverter(astParentNode, cpgAdapter, global) + astToProtoConverter.convert(node) + } + + def getMethod(name: String): List[Node] = + getVertices(name, NodeTypes.METHOD) + + def getTypeDecl(name: String): List[Node] = + getVertices(name, NodeTypes.TYPE_DECL) + + def getCall(name: String): List[Node] = + getVertices(name, NodeTypes.CALL) + + def getVertices(name: String, nodeType: String): List[Node] = { + val result = graph.nodes(nodeType).has(NodeKeysOdb.NAME -> name).toList + + result.size shouldBe 1 + result + } + } + + "Method AST layout" should { + "be correct for empty method" in new Fixture(""" + |void method(int x) { + |}" + """.stripMargin) { + val method = getMethod("method") + method.expandAst(NodeTypes.BLOCK).checkForSingle() + + method + .expandAst(NodeTypes.METHOD_RETURN) + .checkForSingle(NodeKeysOdb.TYPE_FULL_NAME, "void") + + method + .expandAst(NodeTypes.METHOD_PARAMETER_IN) + .checkForSingle(NodeKeysOdb.TYPE_FULL_NAME, "int") + } + + "be correct for decl assignment" in new Fixture(""" + |void method() { + | int local = 1; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val local = block.expandAst(NodeTypes.LOCAL) + local.checkForSingle(NodeKeysOdb.NAME, "local") + local.checkForSingle(NodeKeysOdb.TYPE_FULL_NAME, "int") + + val assignment = block.expandAst(NodeTypes.CALL) + assignment.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + + val arguments = assignment.expandAst() + arguments.check( + 2, + arg => + (arg.label, + arg.property(NodeKeysOdb.CODE), + arg.property(NodeKeysOdb.TYPE_FULL_NAME), + arg.property(NodeKeysOdb.ORDER), + arg.property(NodeKeysOdb.ARGUMENT_INDEX)), + expectations = (NodeTypes.IDENTIFIER, "local", "int", 1, 1), + (NodeTypes.LITERAL, "1", 
"int", 2, 2) + ) + } + + "be correct for decl assignment with identifier on right hand side" in new Fixture(""" + |void method(int x) { + | int local = x; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val local = block.expandAst(NodeTypes.LOCAL) + local.checkForSingle(NodeKeysOdb.NAME, "local") + local.checkForSingle(NodeKeysOdb.TYPE_FULL_NAME, "int") + + val assignment = block.expandAst(NodeTypes.CALL) + assignment.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + + val arguments = assignment.expandAst() + arguments.check( + 2, + arg => + (arg.label, + arg.property(NodeKeysOdb.CODE), + arg.property(NodeKeysOdb.TYPE_FULL_NAME), + arg.property(NodeKeysOdb.ORDER), + arg.property(NodeKeysOdb.ARGUMENT_INDEX)), + expectations = (NodeTypes.IDENTIFIER, "local", "int", 1, 1), + (NodeTypes.IDENTIFIER, "x", "int", 2, 2) + ) + } + + "be correct for decl assignment of multiple locals" in new Fixture(""" + |void method(int x, int y) { + | int local = x, local2 = y; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + block + .expandAst(NodeTypes.LOCAL) + .check( + 2, + local => (local.label, local.property(NodeKeysOdb.CODE), local.property(NodeKeysOdb.TYPE_FULL_NAME)), + expectations = (NodeTypes.LOCAL, "local", "int"), + (NodeTypes.LOCAL, "local2", "int") + ) + + val assignment1 = block.expandAst(NodeTypes.CALL).filterOrder(1) + assignment1.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + + val arguments1 = assignment1.expandAst() + arguments1.check( + 2, + arg => + (arg.label, + arg.property(NodeKeysOdb.CODE), + arg.property(NodeKeysOdb.TYPE_FULL_NAME), + arg.property(NodeKeysOdb.ORDER), + arg.property(NodeKeysOdb.ARGUMENT_INDEX)), + expectations = (NodeTypes.IDENTIFIER, "local", "int", 1, 1), + (NodeTypes.IDENTIFIER, "x", "int", 2, 2) + ) + + val assignment2 = 
block.expandAst(NodeTypes.CALL).filterOrder(2) + assignment2.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + + val arguments2 = assignment2.expandAst() + arguments2.check( + 2, + arg => + (arg.label, + arg.property(NodeKeysOdb.CODE), + arg.property(NodeKeysOdb.TYPE_FULL_NAME), + arg.property(NodeKeysOdb.ORDER), + arg.property(NodeKeysOdb.ARGUMENT_INDEX)), + expectations = (NodeTypes.IDENTIFIER, "local2", "int", 1, 1), + (NodeTypes.IDENTIFIER, "y", "int", 2, 2) + ) + } + + "be correct for nested expression" in new Fixture(""" + |void method() { + | int x; + | int y; + | int z; + | + | x = y + z; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + val locals = block.expandAst(NodeTypes.LOCAL) + locals.check(3, local => local.property(NodeKeysOdb.NAME), expectations = "x", "y", "z") + + val assignment = block.expandAst(NodeTypes.CALL) + assignment.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + + val rightHandSide = assignment.expandAst(NodeTypes.CALL).filterOrder(2) + rightHandSide.checkForSingle(NodeKeysOdb.NAME, Operators.addition) + + val arguments = rightHandSide.expandAst() + arguments.check( + 2, + arg => + (arg.label, + arg.property(NodeKeysOdb.CODE), + arg.property(NodeKeysOdb.ORDER), + arg.property(NodeKeysOdb.ARGUMENT_INDEX)), + expectations = (NodeTypes.IDENTIFIER, "y", 1, 1), + (NodeTypes.IDENTIFIER, "z", 2, 2) + ) + } + + "be correct for nested block" in new Fixture(""" + |void method() { + | int x; + | { + | int y; + | } + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + val locals = block.expandAst(NodeTypes.LOCAL) + locals.checkForSingle(NodeKeysOdb.NAME, "x") + + val nestedBlock = block.expandAst(NodeTypes.BLOCK) + nestedBlock.checkForSingle() + val nestedLocals = nestedBlock.expandAst(NodeTypes.LOCAL) + nestedLocals.checkForSingle(NodeKeysOdb.NAME, "y") + } + + "be 
correct for while-loop" in new Fixture(""" + |void method(int x) { + | while (x < 1) { + | x += 1; + | } + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val whileStmt = block.expandAst(NodeTypes.CONTROL_STRUCTURE) + whileStmt.check(1, _.property(NodeKeysOdb.CODE), expectations = "while (x < 1)") + whileStmt.check(1, whileStmt => whileStmt.property(NodeKeysOdb.PARSER_TYPE_NAME), expectations = "WhileStatement") + + val condition = whileStmt.expandCondition + condition.checkForSingle(NodeKeysOdb.CODE, "x < 1") + + val lessThan = whileStmt.expandAst(NodeTypes.CALL) + lessThan.checkForSingle(NodeKeysOdb.NAME, Operators.lessThan) + + val whileBlock = whileStmt.expandAst(NodeTypes.BLOCK) + whileBlock.checkForSingle() + + val assignPlus = whileBlock.expandAst(NodeTypes.CALL) + assignPlus.filterOrder(1).checkForSingle(NodeKeysOdb.NAME, Operators.assignmentPlus) + } + + "be correct for if" in new Fixture(""" + |void method(int x) { + | int y; + | if (x > 0) { + | y = 0; + | } + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + val ifStmt = block.expandAst(NodeTypes.CONTROL_STRUCTURE) + ifStmt.check(1, _.property(NodeKeysOdb.PARSER_TYPE_NAME), expectations = "IfStatement") + + val condition = ifStmt.expandCondition + condition.checkForSingle(NodeKeysOdb.CODE, "x > 0") + + val greaterThan = ifStmt.expandAst(NodeTypes.CALL) + greaterThan.checkForSingle(NodeKeysOdb.NAME, Operators.greaterThan) + + val ifBlock = ifStmt.expandAst(NodeTypes.BLOCK) + ifBlock.checkForSingle() + + val assignment = ifBlock.expandAst(NodeTypes.CALL) + assignment.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + } + + "be correct for if-else" in new Fixture(""" + |void method(int x) { + | int y; + | if (x > 0) { + | y = 0; + | } else { + | y = 1; + | } + |} + """.stripMargin) { + val method = getMethod("method") + val block = 
method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + val ifStmt = block.expandAst(NodeTypes.CONTROL_STRUCTURE) + ifStmt.check(1, _.property(NodeKeysOdb.PARSER_TYPE_NAME), expectations = "IfStatement") + + val condition = ifStmt.expandCondition + condition.checkForSingle(NodeKeysOdb.CODE, "x > 0") + + val greaterThan = ifStmt.expandAst(NodeTypes.CALL) + greaterThan.checkForSingle(NodeKeysOdb.NAME, Operators.greaterThan) + + val ifBlock = ifStmt.expandAst(NodeTypes.BLOCK) + ifBlock.checkForSingle() + + val assignment = ifBlock.expandAst(NodeTypes.CALL) + assignment.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + + val elseStmt = ifStmt.expandAst(NodeTypes.CONTROL_STRUCTURE) + elseStmt.check(1, _.property(NodeKeysOdb.PARSER_TYPE_NAME), expectations = "ElseStatement") + elseStmt.check(1, _.property(NodeKeysOdb.CODE), "else") + + val elseBlock = elseStmt.expandAst(NodeTypes.BLOCK) + elseBlock.checkForSingle() + + val assignmentInElse = elseBlock.expandAst(NodeTypes.CALL) + assignmentInElse.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + } + + "be correct for conditional expression" in new Fixture( + """ + | void method() { + | int x = (foo == 1) ? bar : 0; + | } + """.stripMargin + ) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + val call = block.expandAst(NodeTypes.CALL) + val conditionalExpr = call.expandAst(NodeTypes.CALL) //formerly control structure + conditionalExpr.check(1, _.property(NodeKeysOdb.CODE), expectations = "(foo == 1) ? 
bar : 0") + conditionalExpr.check(1, _.property(NodeKeysOdb.NAME), expectations = ".conditionalExpression") + val params = conditionalExpr.expandAst() + params.check(3, + arg => (arg.property(NodeKeysOdb.ARGUMENT_INDEX), arg.property(NodeKeysOdb.CODE)), + expectations = (1, "foo == 1"), + (2, "bar"), + (3, "0")) + } + + "be correct for for-loop with multiple initializations" in new Fixture(""" + |void method(int x, int y) { + | for ( x = 0, y = 0; x < 1; x += 1) { + | int z = 0; + | } + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val forLoop = block.expandAst(NodeTypes.CONTROL_STRUCTURE) + forLoop.check(1, _.property(NodeKeysOdb.PARSER_TYPE_NAME), expectations = "ForStatement") + forLoop.check(1, _.property(NodeKeysOdb.CODE), expectations = "for ( x = 0, y = 0; x < 1; x += 1)") + + val conditionNode = forLoop.expandCondition + conditionNode.checkForSingle(NodeKeysOdb.CODE, "x < 1") + + val initBlock = forLoop.expandAst(NodeTypes.BLOCK).filterOrder(1) + initBlock.checkForSingle() + + val assignments = initBlock.expandAst(NodeTypes.CALL) + assignments.check(2, _.property(NodeKeysOdb.NAME), expectations = Operators.assignment) + + val condition = forLoop.expandAst(NodeTypes.CALL).filterOrder(2) + condition.checkForSingle(NodeKeysOdb.NAME, Operators.lessThan) + + val increment = forLoop.expandAst(NodeTypes.CALL).filterOrder(3) + increment.checkForSingle(NodeKeysOdb.NAME, Operators.assignmentPlus) + + val forBlock = forLoop.expandAst(NodeTypes.BLOCK).filterOrder(4) + forBlock.checkForSingle() + } + + "be correct for unary expression '+'" in new Fixture(""" + |void method(int x) { + | +x; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val plusCall = block.expandAst(NodeTypes.CALL) + plusCall.checkForSingle(NodeKeysOdb.NAME, Operators.plus) + + val identifierX = 
plusCall.expandAst(NodeTypes.IDENTIFIER) + identifierX.checkForSingle(NodeKeysOdb.NAME, "x") + } + + "be correct for unary expression '++'" in new Fixture(""" + |void method(int x) { + | ++x; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val plusCall = block.expandAst(NodeTypes.CALL) + plusCall.checkForSingle(NodeKeysOdb.NAME, Operators.preIncrement) + + val identifierX = plusCall.expandAst(NodeTypes.IDENTIFIER) + identifierX.checkForSingle(NodeKeysOdb.NAME, "x") + } + + "be correct for call expression" in new Fixture(""" + |void method(int x) { + | foo(x); + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val call = block.expandAst(NodeTypes.CALL) + call.checkForSingle(NodeKeysOdb.NAME, "foo") + + val argumentX = call.expandAst(NodeTypes.IDENTIFIER) + argumentX.checkForSingle(NodeKeysOdb.NAME, "x") + } + + "be correct for pointer call expression" in new Fixture(""" + |void method(int x) { + | (*funcPointer)(x); + |} + """.stripMargin) {} + + "be correct for member access" in new Fixture(""" + |void method(struct someUndefinedStruct x) { + | x.a; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val fieldAccess = block.expandAst(NodeTypes.CALL) + fieldAccess.checkForSingle(NodeKeysOdb.NAME, Operators.fieldAccess) + + val arguments = fieldAccess.expandAst(NodeTypes.IDENTIFIER) + arguments.check(1, arg => { + (arg.property(NodeKeysOdb.NAME), arg.property(NodeKeysOdb.ARGUMENT_INDEX)) + }, expectations = ("x", 1)) + fieldAccess + .expandAst(NodeTypes.FIELD_IDENTIFIER) + .check(1, arg => { + (arg.property(NodeKeysOdb.CODE), + arg.property(NodeKeysOdb.CANONICAL_NAME), + arg.property(NodeKeysOdb.ARGUMENT_INDEX)) + }, expectations = ("a", "a", 2)) + + } + + "be correct for indirect member access" in new 
Fixture(""" + |void method(struct someUndefinedStruct *x) { + | x->a; + |} + """.stripMargin) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val fieldAccess = block.expandAst(NodeTypes.CALL) + fieldAccess.checkForSingle(NodeKeysOdb.NAME, Operators.indirectFieldAccess) + + val arguments = fieldAccess.expandAst(NodeTypes.IDENTIFIER) + arguments.check(1, arg => { + (arg.property(NodeKeysOdb.NAME), arg.property(NodeKeysOdb.ARGUMENT_INDEX)) + }, expectations = ("x", 1)) + fieldAccess + .expandAst(NodeTypes.FIELD_IDENTIFIER) + .check(1, arg => { + (arg.property(NodeKeysOdb.CODE), + arg.property(NodeKeysOdb.CANONICAL_NAME), + arg.property(NodeKeysOdb.ARGUMENT_INDEX)) + }, expectations = ("a", "a", 2)) + } + + "be correct for sizeof operator on identifier with brackets" in new Fixture( + """ + |void method() { + | int a; + | sizeof(a); + |} + """.stripMargin + ) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val sizeof = block.expandAst(NodeTypes.CALL) + sizeof.checkForSingle(NodeKeysOdb.NAME, Operators.sizeOf) + + val arguments = sizeof.expandAst(NodeTypes.IDENTIFIER) + arguments.checkForSingle(NodeKeysOdb.NAME, "a") + arguments.checkForSingle(NodeKeysOdb.ARGUMENT_INDEX, new Integer(1)) + } + + "be correct for sizeof operator on identifier without brackets" in new Fixture( + """ + |void method() { + | int a; + | sizeof a ; + |} + """.stripMargin + ) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val sizeof = block.expandAst(NodeTypes.CALL) + sizeof.checkForSingle(NodeKeysOdb.NAME, Operators.sizeOf) + + val arguments = sizeof.expandAst(NodeTypes.IDENTIFIER) + arguments.checkForSingle(NodeKeysOdb.NAME, "a") + arguments.checkForSingle(NodeKeysOdb.ARGUMENT_INDEX, new Integer(1)) + } + + "be correct for sizeof operator on type" in new Fixture( + """ + |void method() { + | 
sizeof(int); + |} + """.stripMargin + ) { + val method = getMethod("method") + val block = method.expandAst(NodeTypes.BLOCK) + block.checkForSingle() + + val sizeof = block.expandAst(NodeTypes.CALL) + sizeof.checkForSingle(NodeKeysOdb.NAME, Operators.sizeOf) + + // For us it is undecidable whether "int" is a type or an Identifier + // Thus the implementation always goes for Identifier which we encode + // here in the tests. + val arguments = sizeof.expandAst(NodeTypes.IDENTIFIER) + arguments.checkForSingle(NodeKeysOdb.NAME, "int") + arguments.checkForSingle(NodeKeysOdb.ARGUMENT_INDEX, new Integer(1)) + } + } + + "Structural AST layout" should { + "be correct for empty method" in new Fixture(""" + | void method() { + | }; + """.stripMargin) { + val method = getMethod("method") + method.checkForSingle() + + astParent.expandAst(NodeTypes.METHOD) shouldBe method + } + + "be correct for empty named struct" in new Fixture(""" + | struct foo { + | }; + """.stripMargin) { + val typeDecl = getTypeDecl("foo") + typeDecl.checkForSingle() + + astParent.expandAst(NodeTypes.TYPE_DECL) shouldBe typeDecl + } + + "be correct for named struct with single field" in new Fixture(""" + | struct foo { + | int x; + | }; + """.stripMargin) { + val typeDecl = getTypeDecl("foo") + typeDecl.checkForSingle() + val member = typeDecl.expandAst(NodeTypes.MEMBER) + member.checkForSingle(NodeKeysOdb.CODE, "x") + member.checkForSingle(NodeKeysOdb.NAME, "x") + member.checkForSingle(NodeKeysOdb.TYPE_FULL_NAME, "int") + } + + "be correct for named struct with multiple fields" in new Fixture(""" + | struct foo { + | int x; + | int y; + | int z; + | }; + """.stripMargin) { + val typeDecl = getTypeDecl("foo") + typeDecl.checkForSingle() + val member = typeDecl.expandAst(NodeTypes.MEMBER) + member.check(3, member => member.property(NodeKeysOdb.CODE), expectations = "x", "y", "z") + } + + "be correct for named struct with nested struct" in new Fixture(""" + | struct foo { + | int x; + | struct bar { + | int 
y; + | struct foo2 { + | int z; + | }; + | }; + | }; + """.stripMargin) { + val typeDeclFoo = getTypeDecl("foo") + typeDeclFoo.checkForSingle() + val memberFoo = typeDeclFoo.expandAst(NodeTypes.MEMBER) + memberFoo.checkForSingle(NodeKeysOdb.CODE, "x") + + val typeDeclBar = typeDeclFoo.expandAst(NodeTypes.TYPE_DECL) + typeDeclBar.checkForSingle(NodeKeysOdb.FULL_NAME, "bar") + val memberBar = typeDeclBar.expandAst(NodeTypes.MEMBER) + memberBar.checkForSingle(NodeKeysOdb.CODE, "y") + + val typeDeclFoo2 = typeDeclBar.expandAst(NodeTypes.TYPE_DECL) + typeDeclFoo2.checkForSingle(NodeKeysOdb.FULL_NAME, "foo2") + val memberFoo2 = typeDeclFoo2.expandAst(NodeTypes.MEMBER) + memberFoo2.checkForSingle(NodeKeysOdb.CODE, "z") + } + + "be correct for typedef" in new Fixture( + """ + |typedef struct foo { + |} abc; + """.stripMargin + ) { + val aliasTypeDecl = getTypeDecl("abc") + + aliasTypeDecl.checkForSingle(NodeKeysOdb.FULL_NAME, "abc") + aliasTypeDecl.checkForSingle(NodeKeysOdb.ALIAS_TYPE_FULL_NAME, "foo") + } + + "be correct for single inheritance" in new Fixture( + """ + |class Base {public: int i;}; + |class Derived : public Base{ + |public: + | char x; + | int method(){return i;}; + |}; + """.stripMargin + ) { + + val derivedL = getTypeDecl("Derived") + derivedL.checkForSingle() + + val derived = derivedL.head + derived.value[List[String]](NodeKeys.INHERITS_FROM_TYPE_FULL_NAME.name) shouldBe List("Base") + } + + "be correct for multiple inheritance" in new Fixture( + """ + |class OneBase {public: int i;}; + |class TwoBase {public: int j;}; + | + |class Derived : public OneBase, protected TwoBase{ + |public: + | char x; + | int method(){return i;}; + |}; + """.stripMargin + ) { + + val derivedL = getTypeDecl("Derived") + derivedL.checkForSingle() + + val derived = derivedL.head + derived.value[List[String]](NodeKeys.INHERITS_FROM_TYPE_FULL_NAME.name) shouldBe List("OneBase", "TwoBase") + } + + "be correct for method calls" in new Fixture( + """ + |void foo(int x) { + | 
bar(x); + |} + |""".stripMargin + ) { + val call = getCall("bar") + call.checkForSingle() + + val args = call.expandArgument + args.checkForSingle(NodeKeysOdb.CODE, "x") + } + + "be correct for method returns" in new Fixture( + """ + |void double(int x) { + | return x * 2; + |} + |""".stripMargin + ) { + val method = getMethod("double") + method.checkForSingle() + + val methodBody = method.expandAst(NodeTypes.BLOCK) + methodBody.checkForSingle() + + val methodReturn = methodBody.expandAst(NodeTypes.RETURN) + methodReturn.checkForSingle() + + val args = methodReturn.expandArgument + args.checkForSingle(NodeKeysOdb.CODE, "x * 2") + } + + "be correct for binary method calls" in new Fixture( + """ + |void double(int x) { + | return x * 2; + |} + |""".stripMargin + ) { + val call = getCall(".multiplication") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "x", "2") + } + + "be correct for unary method calls" in new Fixture( + """ + |bool invert(bool b) { + | return !b; + |} + |""".stripMargin + ) { + val call = getCall(".logicalNot") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.checkForSingle(NodeKeysOdb.CODE, "b") + } + + "be correct for post increment method calls" in new Fixture( + """ + |int foo(int x) { + | int sub = x--; + | int pos = x++; + | return pos; + |} + |""".stripMargin + ) { + val call = getCall(".postIncrement") + call.checkForSingle() + val callArgs = call.expandArgument + callArgs.checkForSingle(NodeKeysOdb.CODE, "x") + + val callDec = getCall(".postDecrement") + callDec.checkForSingle() + val callArgsDec = callDec.expandArgument + callArgsDec.checkForSingle(NodeKeysOdb.CODE, "x") + } + + "be correct for conditional expressions containing calls" in new Fixture( + """ + |int abs(int x) { + | return x > 0 ? 
x : -x; + |} + |""".stripMargin + ) { + val call = getCall(".conditionalExpression") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.check(3, x => x.property(NodeKeysOdb.CODE), "x > 0", "x", "-x") + } + + "be correct for sizeof expressions" in new Fixture( + """ + |size_t int_size() { + | return sizeof(int); + |} + |""".stripMargin + ) { + val call = getCall(".sizeOf") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.checkForSingle(NodeKeysOdb.CODE, "int") + } + + "be correct for label" in new Fixture("foo() { label: }") { + val jumpTarget = getVertices("label", NodeTypes.JUMP_TARGET) + jumpTarget.checkForSingle(NodeKeysOdb.CODE, "label:") + } + + "be correct for array indexing" in new Fixture( + """ + |int head(int x[]) { + | return x[0]; + |} + |""".stripMargin + ) { + val call = getCall(".indirectIndexAccess") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "x", "0") + } + + "be correct for type casts" in new Fixture( + """ + |int trunc(long x) { + | return (int) x; + |} + |""".stripMargin + ) { + val call = getCall(".cast") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "int", "x") + } + + "be correct for member accesses" in new Fixture( + """ + |int trunc(Foo x) { + | return x.count; + |} + |""".stripMargin + ) { + val call = getCall(".fieldAccess") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "x", "count") + callArgs.check(2, x => x.label(), NodeTypes.IDENTIFIER, NodeTypes.FIELD_IDENTIFIER) + callArgs.check(2, x => { + if (x.label() == NodeTypes.FIELD_IDENTIFIER) { x.property(NodeKeysOdb.CANONICAL_NAME) } else { "" } + }, "", "count") + } + + "be correct for indirect member accesses" in new Fixture( + """ + |int trunc(Foo* x) { + | return x->count; + |} + |""".stripMargin + ) { + val call = 
getCall(".indirectFieldAccess") + call.checkForSingle() + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "x", "count") + callArgs.check(2, x => x.label(), NodeTypes.IDENTIFIER, NodeTypes.FIELD_IDENTIFIER) + callArgs.check(2, x => { + if (x.label() == NodeTypes.FIELD_IDENTIFIER) { x.property(NodeKeysOdb.CANONICAL_NAME) } else { "" } + }, "", "count") + } + + "be correct for 'new' array" in new Fixture( + """ + |int[] alloc(int n) { + | int[] arr = new int[n]; + | return arr; + |} + |""".stripMargin + ) { + val call = getCall(".new") + call.checkForSingle(NodeKeysOdb.CODE, "new int[n]") + + val callArgs = call.expandArgument + callArgs.check(1, x => x.property(NodeKeysOdb.CODE), "int") + } + + "be correct for 'new' object" in new Fixture( + """ + |Foo* alloc(int n) { + | Foo* foo = new Foo(n, 42); + | return foo; + |} + |""".stripMargin + ) { + val call = getCall(".new") + call.checkForSingle(NodeKeysOdb.CODE, "new Foo(n, 42)") + + val callArgs = call.expandArgument + callArgs.check(1, x => x.property(NodeKeysOdb.CODE), "Foo") + } + + "be correct for simple 'delete'" in new Fixture( + """ + |int delete_number(int* n) { + | delete n; + |} + |""".stripMargin + ) { + val call = getCall(".delete") + call.checkForSingle(NodeKeysOdb.CODE, "delete n") + + val callArgs = call.expandArgument + callArgs.check(1, x => x.property(NodeKeysOdb.CODE), "n") + } + + "be correct for array 'delete'" in new Fixture( + """ + |void delete_number(int n[]) { + | delete[] n; + |} + |""".stripMargin + ) { + val call = getCall(".delete") + call.checkForSingle(NodeKeysOdb.CODE, "delete[] n") + + val callArgs = call.expandArgument + callArgs.check(1, x => x.property(NodeKeysOdb.CODE), "n") + } + + "be correct for const_cast" in new Fixture( + """ + |void foo() { + | int y = const_cast(n); + | return; + |} + |""".stripMargin + ) { + val call = getCall(".cast") + call.checkForSingle(NodeKeysOdb.CODE, "const_cast(n)") + + val callArgs = 
call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "int", "n") + } + + "be correct for static_cast" in new Fixture( + """ + |void foo() { + | int y = static_cast(n); + | return; + |} + |""".stripMargin + ) { + val call = getCall(".cast") + call.checkForSingle(NodeKeysOdb.CODE, "static_cast(n)") + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "int", "n") + } + + "be correct for dynamic_cast" in new Fixture( + """ + |void foo() { + | int y = dynamic_cast(n); + | return; + |} + |""".stripMargin + ) { + val call = getCall(".cast") + call.checkForSingle(NodeKeysOdb.CODE, "dynamic_cast(n)") + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "int", "n") + } + + "be correct for reinterpret_cast" in new Fixture( + """ + |void foo() { + | int y = reinterpret_cast(n); + | return; + |} + |""".stripMargin + ) { + val call = getCall(".cast") + call.checkForSingle(NodeKeysOdb.CODE, "reinterpret_cast(n)") + + val callArgs = call.expandArgument + callArgs.check(2, x => x.property(NodeKeysOdb.CODE), "int", "n") + } + } + + "AST" should { + "have correct line number for method content" in new Fixture(""" + | + | + | + | + | void method(int x) { + | + | x = 1; + | } + """.stripMargin) { + val method = getMethod("method") + method.checkForSingle(NodeKeysOdb.LINE_NUMBER, 6: Integer) + + val block = method.expandAst(NodeTypes.BLOCK) + + val assignment = block.expandAst(NodeTypes.CALL) + assignment.checkForSingle(NodeKeysOdb.NAME, Operators.assignment) + assignment.checkForSingle(NodeKeysOdb.LINE_NUMBER, 8: Integer) + } + } +} diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/cfg/AstToCfgTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/cfg/AstToCfgTests.scala new file mode 100644 index 0000000..5ee74e3 --- /dev/null +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/cfg/AstToCfgTests.scala @@ -0,0 +1,494 @@ +package io.shiftleft.fuzzyc2cpg.cfg + +import 
io.shiftleft.fuzzyc2cpg.adapter.EdgeKind.EdgeKind +import io.shiftleft.fuzzyc2cpg.adapter.EdgeProperty.EdgeProperty +import io.shiftleft.fuzzyc2cpg.adapter.NodeKind.NodeKind +import io.shiftleft.fuzzyc2cpg.adapter.NodeProperty.NodeProperty +import io.shiftleft.fuzzyc2cpg.adapter.{AlwaysEdge, CaseEdge, CfgEdgeType, CpgAdapter, FalseEdge, TrueEdge} +import io.shiftleft.fuzzyc2cpg.ast.AstNode +import io.shiftleft.fuzzyc2cpg.parsetreetoast.FunctionContentTestUtil +import org.scalatest.{Matchers, WordSpec} + +class AstToCfgTests extends WordSpec with Matchers { + private case class CfgNodeEdgePairBuilder(dstCfgNode: CfgNode, srcCfgNode: CfgNode, var cfgEdgeType: String) + + private case class CfgNodeEdgePair(cfgNode: CfgNode, cfgEdgeType: String) { + override def toString: String = { + s"$cfgEdgeType ==> ${cfgNode.code}" + } + } + private class CfgNode(val code: String, var successors: Set[CfgNodeEdgePair] = Set()) { + override def toString: String = { + s"$code === ${successors.mkString(", ")}" + } + } + + private class GraphAdapter extends CpgAdapter[CfgNode, CfgNode, CfgNodeEdgePairBuilder, CfgNodeEdgePair] { + private var mapping = Map[AstNode, CfgNode]() + var codeToCfgNode = Map[String, CfgNode]() + + override def mapNode(astNode: AstNode): CfgNode = { + if (mapping.contains(astNode)) { + mapping(astNode) + } else { + val cfgNode = new CfgNode(astNode.getEscapedCodeStr) + mapping += astNode -> cfgNode + codeToCfgNode += astNode.getEscapedCodeStr -> cfgNode + cfgNode + } + } + + // Not used in test with this adapter. + override def createNodeBuilder(kind: NodeKind): CfgNode = ??? + override def createNode(nodeBuilder: CfgNode): CfgNode = ??? + override def createNode(nodeBuilder: CfgNode, origAstNode: AstNode): CfgNode = ??? + override def addNodeProperty(nodeBuilder: CfgNode, property: NodeProperty, value: String): Unit = ??? + override def addNodeProperty(nodeBuilder: CfgNode, property: NodeProperty, value: Int): Unit = ??? 
+ override def addNodeProperty(nodeBuilder: CfgNode, property: NodeProperty, value: Boolean): Unit = ??? + override def addNodeProperty(nodeBuilder: CfgNode, property: NodeProperty, value: List[String]): Unit = ??? + + override def createEdgeBuilder(dst: CfgNode, src: CfgNode, edgeKind: EdgeKind): CfgNodeEdgePairBuilder = { + if (src.successors.exists(_.cfgNode == dst)) { + throw new RuntimeException("Found duplicate edge.") + } + CfgNodeEdgePairBuilder(dst, src, null) + } + + override def createEdge(edgeBuilder: CfgNodeEdgePairBuilder): CfgNodeEdgePair = { + val newEdge = CfgNodeEdgePair(edgeBuilder.dstCfgNode, edgeBuilder.cfgEdgeType) + edgeBuilder.srcCfgNode.successors = edgeBuilder.srcCfgNode.successors + newEdge + newEdge + } + + override def addEdgeProperty(edgeBuilder: CfgNodeEdgePairBuilder, property: EdgeProperty, value: String): Unit = { + edgeBuilder.cfgEdgeType = value + } + } + + private class Fixture(code: String) { + private val astRoot = FunctionContentTestUtil.parseAndWalk(code) + private val entry = new CfgNode("ENTRY") + private val exit = new CfgNode("EXIT") + + private val adapter = new GraphAdapter() + private val astToCfgConverter = new AstToCfgConverter(entry, exit, adapter) + astToCfgConverter.convert(astRoot) + + private var codeToCpgNode = adapter.codeToCfgNode + codeToCpgNode += entry.code -> entry + codeToCpgNode += exit.code -> exit + + def expected(pairs: (String, CfgEdgeType)*): Set[CfgNodeEdgePair] = { + pairs.map { + case (code, cfgEdgeType) => + CfgNodeEdgePair(codeToCpgNode(code), cfgEdgeType.toString) + }.toSet + } + + def succOf(code: String): Set[CfgNodeEdgePair] = { + codeToCpgNode(code).successors + } + } + + "Cfg" should { + "be correct for decl statement with assignment" in + new Fixture("int x = 1;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x = 1", AlwaysEdge)) + succOf("x = 1") shouldBe expected(("EXIT", AlwaysEdge)) + } 
+ + "be correct for nested expression" in + new Fixture("x = y + 1;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", AlwaysEdge)) + succOf("y") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("y + 1", AlwaysEdge)) + succOf("y + 1") shouldBe expected(("x = y + 1", AlwaysEdge)) + succOf("x = y + 1") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for return statement" in + new Fixture("return x;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("return x;", AlwaysEdge)) + succOf("return x;") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for consecutive return statements" in + new Fixture("return x; return y;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("return x;", AlwaysEdge)) + succOf("y") shouldBe expected(("return y;", AlwaysEdge)) + succOf("return x;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("return y;") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for void return statement" in + new Fixture("return;") { + succOf("ENTRY") shouldBe expected(("return;", AlwaysEdge)) + succOf("return;") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for call expression" in + new Fixture("foo(a + 1, b);") { + succOf("ENTRY") shouldBe expected(("a", AlwaysEdge)) + succOf("a") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("a + 1", AlwaysEdge)) + succOf("a + 1") shouldBe expected(("b", AlwaysEdge)) + succOf("b") shouldBe expected(("foo(a + 1, b)", AlwaysEdge)) + succOf("foo(a + 1, b)") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for unary expression '+'" in + new Fixture("+x;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("+x", AlwaysEdge)) + succOf("+x") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for unary expression '++'" in + new Fixture("++x;") { + succOf("ENTRY") 
shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("++x", AlwaysEdge)) + succOf("++x") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for conditional expression" in + new Fixture("x ? y : z;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", TrueEdge), ("z", FalseEdge)) + succOf("y") shouldBe expected(("x ? y : z", AlwaysEdge)) + succOf("z") shouldBe expected(("x ? y : z", AlwaysEdge)) + succOf("x ? y : z") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for short-circuit AND expression" in + // TODO: Broken by supporting move params? + new Fixture("x && y;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", TrueEdge), ("x && y", FalseEdge)) + succOf("y") shouldBe expected(("x && y", AlwaysEdge)) + succOf("x && y") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for short-circuit OR expression" in + new Fixture("x || y;") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", FalseEdge), ("x || y", TrueEdge)) + succOf("y") shouldBe expected(("x || y", AlwaysEdge)) + succOf("x || y") shouldBe expected(("EXIT", AlwaysEdge)) + } + } + + "Cfg for while-loop" should { + "be correct" in + new Fixture("while (x < 1) { y = 2; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) + succOf("x < 1") shouldBe expected(("y", TrueEdge), ("EXIT", FalseEdge)) + succOf("y") shouldBe expected(("2", AlwaysEdge)) + succOf("2") shouldBe expected(("y = 2", AlwaysEdge)) + succOf("y = 2") shouldBe expected(("x", AlwaysEdge)) + } + + "be correct with break" in + new Fixture("while (x < 1) { break; y; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) + succOf("x < 1") shouldBe 
expected(("break;", TrueEdge), ("EXIT", FalseEdge)) + succOf("break;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("y") shouldBe expected(("x", AlwaysEdge)) + } + + "be correct with continue" in + new Fixture("while (x < 1) { continue; y; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) + succOf("x < 1") shouldBe expected(("continue;", TrueEdge), ("EXIT", FalseEdge)) + succOf("continue;") shouldBe expected(("x", AlwaysEdge)) + succOf("y") shouldBe expected(("x", AlwaysEdge)) + } + + "be correct with nested while-loop" in + new Fixture("while (x) { while (y) { z; }}") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", TrueEdge), ("EXIT", FalseEdge)) + succOf("y") shouldBe expected(("z", TrueEdge), ("x", FalseEdge)) + succOf("z") shouldBe expected(("y", AlwaysEdge)) + } + } + + "Cfg for do-while-loop" should { + "be correct" in + new Fixture("do { y = 2; } while (x < 1);") { + succOf("ENTRY") shouldBe expected(("y", AlwaysEdge)) + succOf("ENTRY") shouldBe expected(("y", AlwaysEdge)) + succOf("y") shouldBe expected(("2", AlwaysEdge)) + succOf("2") shouldBe expected(("y = 2", AlwaysEdge)) + succOf("y = 2") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) + succOf("x < 1") shouldBe expected(("y", TrueEdge), ("EXIT", FalseEdge)) + } + + "be correct with break" in + new Fixture("do { break; y; } while (x < 1);") { + succOf("ENTRY") shouldBe expected(("break;", AlwaysEdge)) + succOf("break;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("y") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) + succOf("x < 1") shouldBe expected(("break;", TrueEdge), ("EXIT", FalseEdge)) + } + + "be correct with continue" in + new 
Fixture("do { continue; y; } while (x < 1);") { + succOf("ENTRY") shouldBe expected(("continue;", AlwaysEdge)) + succOf("continue;") shouldBe expected(("x", AlwaysEdge)) + succOf("y") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x < 1", AlwaysEdge)) + succOf("x < 1") shouldBe expected(("continue;", TrueEdge), ("EXIT", FalseEdge)) + } + + "be correct with nested do-while-loop" in + new Fixture("do { do { x; } while (y); } while (z);") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", AlwaysEdge)) + succOf("y") shouldBe expected(("x", TrueEdge), ("z", FalseEdge)) + succOf("z") shouldBe expected(("x", TrueEdge), ("EXIT", FalseEdge)) + } + } + + "Cfg for for-loop" should { + "be correct" in + new Fixture("for (x = 0; y < 1; z += 2) { a = 3; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("0", AlwaysEdge)) + succOf("0") shouldBe expected(("x = 0", AlwaysEdge)) + succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) + succOf("y") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) + succOf("y < 1") shouldBe expected(("a", TrueEdge), ("EXIT", FalseEdge)) + succOf("a") shouldBe expected(("3", AlwaysEdge)) + succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) + succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) + succOf("z") shouldBe expected(("2", AlwaysEdge)) + succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) + succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + } + + "be correct with break" in + new Fixture("for (x = 0; y < 1; z += 2) { break; a = 3; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("0", AlwaysEdge)) + succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) + succOf("y") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) + succOf("y < 1") shouldBe 
expected(("break;", TrueEdge), ("EXIT", FalseEdge)) + succOf("break;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("a") shouldBe expected(("3", AlwaysEdge)) + succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) + succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) + succOf("z") shouldBe expected(("2", AlwaysEdge)) + succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) + succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + } + + "be correct with continue" in + new Fixture("for (x = 0; y < 1; z += 2) { continue; a = 3; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("0", AlwaysEdge)) + succOf("0") shouldBe expected(("x = 0", AlwaysEdge)) + succOf("x = 0") shouldBe expected(("y", AlwaysEdge)) + succOf("y") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("y < 1", AlwaysEdge)) + succOf("y < 1") shouldBe expected(("continue;", TrueEdge), ("EXIT", FalseEdge)) + succOf("continue;") shouldBe expected(("z", AlwaysEdge)) + succOf("a") shouldBe expected(("3", AlwaysEdge)) + succOf("3") shouldBe expected(("a = 3", AlwaysEdge)) + succOf("a = 3") shouldBe expected(("z", AlwaysEdge)) + succOf("z") shouldBe expected(("2", AlwaysEdge)) + succOf("2") shouldBe expected(("z += 2", AlwaysEdge)) + succOf("z += 2") shouldBe expected(("y", AlwaysEdge)) + } + + "be correct with nested for-loop" in + new Fixture("for (x; y; z) { for (a; b; c) { u; } }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", AlwaysEdge)) + succOf("y") shouldBe expected(("a", TrueEdge), ("EXIT", FalseEdge)) + succOf("z") shouldBe expected(("y", AlwaysEdge)) + succOf("a") shouldBe expected(("b", AlwaysEdge)) + succOf("b") shouldBe expected(("u", TrueEdge), ("z", FalseEdge)) + succOf("c") shouldBe expected(("b", AlwaysEdge)) + succOf("u") shouldBe expected(("c", AlwaysEdge)) + } + + "be correct with empty condition" in + new Fixture("for (;;) { a = 1; }") { + succOf("ENTRY") shouldBe 
expected(("a", AlwaysEdge)) + succOf("a") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("a = 1", AlwaysEdge)) + succOf("a = 1") shouldBe expected(("a", AlwaysEdge)) + } + + "be correct with empty condition with break" in + new Fixture("for (;;) { break; }") { + succOf("ENTRY") shouldBe expected(("break;", AlwaysEdge)) + succOf("break;") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with empty condition with continue" in + new Fixture("for (;;) { continue ; }") { + succOf("ENTRY") shouldBe expected(("continue ;", AlwaysEdge)) + succOf("continue ;") shouldBe expected(("continue ;", AlwaysEdge)) + } + + "be correct with empty condition with nested empty for-loop" in + new Fixture("for (;;) { for (;;) { x; } }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("x", AlwaysEdge)) + } + + "be correct with empty condition with empty block" in + new Fixture("for (;;) ;") { + succOf("ENTRY") shouldBe expected() + } + + "be correct when empty for-loop is skipped" in + new Fixture("for (;;) {}; return;") { + succOf("ENTRY") shouldBe expected() + succOf("return;") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with function call condition with empty block" in + new Fixture("for (; x(1);) ;") { + succOf("ENTRY") shouldBe expected(("1", AlwaysEdge)) + succOf("1") shouldBe expected(("x(1)", AlwaysEdge)) + succOf("x(1)") shouldBe expected(("1", TrueEdge), ("EXIT", FalseEdge)) + } + } + + "Cfg for goto" should { + "be correct for single label" in + new Fixture("x; goto l1; y; l1:") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("goto l1;", AlwaysEdge)) + succOf("goto l1;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("y") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for multiple labels" in + new Fixture("x;goto l1; l2: y; l1:") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("goto l1;", 
AlwaysEdge)) + succOf("goto l1;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("y") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for multiple labels on same spot" in + new Fixture("x;goto l2;y;l1:l2:") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("goto l2;", AlwaysEdge)) + succOf("goto l2;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("y") shouldBe expected(("EXIT", AlwaysEdge)) + } + } + + "Cfg for switch" should { + "be correct with one case" in + new Fixture("switch (x) { case 1: y; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("EXIT", CaseEdge)) + succOf("y") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with multiple cases" in + new Fixture("switch (x) { case 1: y; case 2: z;}") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("z", CaseEdge), ("EXIT", CaseEdge)) + succOf("y") shouldBe expected(("z", AlwaysEdge)) + succOf("z") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with multiple cases on same spot" in + new Fixture("switch (x) { case 1: case 2: y; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("EXIT", CaseEdge)) + succOf("y") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with multiple cases and multiple cases on same spot" in + new Fixture("switch (x) { case 1: case 2: y; case 3: z;}") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("z", CaseEdge), ("EXIT", CaseEdge)) + succOf("y") shouldBe expected(("z", AlwaysEdge)) + succOf("z") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with default case" in + new Fixture("switch (x) { default: y; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge)) + succOf("y") shouldBe expected(("EXIT", 
AlwaysEdge)) + } + + "be correct for case and default combined" in + new Fixture("switch (x) { case 1: y; break; default: z;}") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("z", CaseEdge)) + succOf("y") shouldBe expected(("break;", AlwaysEdge)) + succOf("break;") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("z") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct for nested switch" in + new Fixture("switch (x) { default: switch(y) { default: z; } }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge)) + succOf("y") shouldBe expected(("z", CaseEdge)) + succOf("z") shouldBe expected(("EXIT", AlwaysEdge)) + } + } + + "Cfg for if" should { + "be correct" in + new Fixture("if (x) { y; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", TrueEdge), ("EXIT", FalseEdge)) + succOf("y") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with else block" in + new Fixture("if (x) { y; } else { z; }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", TrueEdge), ("z", FalseEdge)) + succOf("y") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("z") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with nested if" in + new Fixture("if (x) { if (y) { z; } }") { + succOf("ENTRY") shouldBe expected(("x", AlwaysEdge)) + succOf("x") shouldBe expected(("y", TrueEdge), ("EXIT", FalseEdge)) + succOf("y") shouldBe expected(("z", TrueEdge), ("EXIT", FalseEdge)) + succOf("z") shouldBe expected(("EXIT", AlwaysEdge)) + } + + "be correct with else if chain" in + new Fixture("if (a) { b; } else if (c) { d;} else { e; }") { + succOf("ENTRY") shouldBe expected(("a", AlwaysEdge)) + succOf("a") shouldBe expected(("b", TrueEdge), ("c", FalseEdge)) + succOf("b") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("c") shouldBe expected(("d", TrueEdge), ("e", FalseEdge)) + 
succOf("d") shouldBe expected(("EXIT", AlwaysEdge)) + succOf("e") shouldBe expected(("EXIT", AlwaysEdge)) + } + } +} diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPassTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPassTests.scala index 4d7f16d..cb14876 100644 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPassTests.scala +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CMetaDataPassTests.scala @@ -13,12 +13,12 @@ class CMetaDataPassTests extends WordSpec with Matchers { val cpg = Cpg.emptyCpg new CMetaDataPass(cpg).createAndApply() - "create exactly three nodes" in { - cpg.graph.V.asScala.size shouldBe 3 + "create exactly two nodes" in { + cpg.graph.V.asScala.size shouldBe 2 } - "create one edge" in { - cpg.graph.E.asScala.size shouldBe 1 + "create no edges" in { + cpg.graph.E.asScala.size shouldBe 0 } "create a metadata node with correct language" in { @@ -28,10 +28,5 @@ class CMetaDataPassTests extends WordSpec with Matchers { "create a '' NamespaceBlock" in { cpg.namespaceBlock.name.l shouldBe List(Defines.globalNamespaceName) } - - "connect '' with a file node" in { - cpg.namespaceBlock.name("").file.size shouldBe 1 - } - } } diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPassTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPassTests.scala index 9dfc561..3d17e1e 100644 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPassTests.scala +++ b/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/CfgCreationPassTests.scala @@ -2,23 +2,17 @@ package io.shiftleft.fuzzyc2cpg.passes import better.files.File import io.shiftleft.codepropertygraph.Cpg +import io.shiftleft.fuzzyc2cpg.adapter.{AlwaysEdge, CaseEdge, CfgEdgeType, FalseEdge, TrueEdge} import io.shiftleft.passes.IntervalKeyPool import org.scalatest.{Matchers, WordSpec} import io.shiftleft.semanticcpg.language._ import scala.jdk.CollectionConverters._ import 
io.shiftleft.codepropertygraph.generated.nodes -import io.shiftleft.fuzzyc2cpg.passes.cfgcreation.Cfg.{AlwaysEdge, CaseEdge, CfgEdgeType, FalseEdge, TrueEdge} class CfgCreationPassTests extends WordSpec with Matchers { "Cfg" should { - - "contain an entry and exit node at least" in new CfgFixture("") { - succOf("func ()") shouldBe expected(("RET", AlwaysEdge)) - succOf("RET") shouldBe expected() - } - "be correct for decl statement with assignment" in new CfgFixture("int x = 1;") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) @@ -190,14 +184,6 @@ class CfgCreationPassTests extends WordSpec with Matchers { succOf("y") shouldBe expected(("x", TrueEdge), ("z", FalseEdge)) succOf("z") shouldBe expected(("x", TrueEdge), ("RET", FalseEdge)) } - - "be correct for do-while-loop with empty body" in - new CfgFixture("do { } while(x > 1);") { - succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("1") shouldBe expected(("x > 1", AlwaysEdge)) - succOf("x > 1") shouldBe expected(("x", TrueEdge), ("RET", FalseEdge)) - } - } "Cfg for for-loop" should { @@ -293,12 +279,12 @@ class CfgCreationPassTests extends WordSpec with Matchers { "be correct with empty condition with empty block" in new CfgFixture("for (;;) ;") { - succOf("func ()") shouldBe expected(("RET", AlwaysEdge)) + succOf("func ()") shouldBe expected() } "be correct when empty for-loop is skipped" in new CfgFixture("for (;;) {}; return;") { - succOf("func ()") shouldBe expected(("return;", AlwaysEdge)) + succOf("func ()") shouldBe expected() succOf("return;") shouldBe expected(("RET", AlwaysEdge)) } @@ -315,28 +301,24 @@ class CfgCreationPassTests extends WordSpec with Matchers { new CfgFixture("x; goto l1; y; l1:") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) succOf("x") shouldBe expected(("goto l1;", AlwaysEdge)) - succOf("goto l1;") shouldBe expected(("l1:", AlwaysEdge)) - succOf("l1:") shouldBe expected(("RET", AlwaysEdge)) - succOf("y") shouldBe expected(("l1:", AlwaysEdge)) + 
succOf("goto l1;") shouldBe expected(("RET", AlwaysEdge)) + succOf("y") shouldBe expected(("RET", AlwaysEdge)) } "be correct for multiple labels" in new CfgFixture("x;goto l1; l2: y; l1:") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) succOf("x") shouldBe expected(("goto l1;", AlwaysEdge)) - succOf("goto l1;") shouldBe expected(("l1:", AlwaysEdge)) - succOf("y") shouldBe expected(("l1:", AlwaysEdge)) - succOf("l1:") shouldBe expected(("RET", AlwaysEdge)) + succOf("goto l1;") shouldBe expected(("RET", AlwaysEdge)) + succOf("y") shouldBe expected(("RET", AlwaysEdge)) } "be correct for multiple labels on same spot" in new CfgFixture("x;goto l2;y;l1:l2:") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) succOf("x") shouldBe expected(("goto l2;", AlwaysEdge)) - succOf("goto l2;") shouldBe expected(("l2:", AlwaysEdge)) - succOf("y") shouldBe expected(("l1:", AlwaysEdge)) - succOf("l1:") shouldBe expected(("l2:", AlwaysEdge)) - succOf("l2:") shouldBe expected(("RET", AlwaysEdge)) + succOf("goto l2;") shouldBe expected(("RET", AlwaysEdge)) + succOf("y") shouldBe expected(("RET", AlwaysEdge)) } } @@ -344,69 +326,54 @@ class CfgCreationPassTests extends WordSpec with Matchers { "be correct with one case" in new CfgFixture("switch (x) { case 1: y; }") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("y", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("RET", CaseEdge)) succOf("y") shouldBe expected(("RET", AlwaysEdge)) } "be correct with multiple cases" in new CfgFixture("switch (x) { case 1: y; case 2: z;}") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("case 2:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("case 2:", AlwaysEdge)) - succOf("case 2:") shouldBe expected(("z", 
AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("z", CaseEdge), ("RET", CaseEdge)) + succOf("y") shouldBe expected(("z", AlwaysEdge)) succOf("z") shouldBe expected(("RET", AlwaysEdge)) } "be correct with multiple cases on same spot" in new CfgFixture("switch (x) { case 1: case 2: y; }") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("case 2:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("case 2:", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("RET", CaseEdge)) succOf("y") shouldBe expected(("RET", AlwaysEdge)) } "be correct with multiple cases and multiple cases on same spot" in new CfgFixture("switch (x) { case 1: case 2: y; case 3: z;}") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), - ("case 2:", CaseEdge), - ("case 3:", CaseEdge), - ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("case 2:", AlwaysEdge)) - succOf("case 2:") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("case 3:", AlwaysEdge)) - succOf("case 3:") shouldBe expected(("z", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("z", CaseEdge), ("RET", CaseEdge)) + succOf("y") shouldBe expected(("z", AlwaysEdge)) succOf("z") shouldBe expected(("RET", AlwaysEdge)) } "be correct with default case" in new CfgFixture("switch (x) { default: y; }") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("default:", CaseEdge)) - succOf("default:") shouldBe expected(("y", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge)) succOf("y") shouldBe expected(("RET", AlwaysEdge)) } "be correct for case and default combined" in new CfgFixture("switch (x) { case 1: y; break; default: z;}") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("default:", CaseEdge)) - succOf("case 1:") 
shouldBe expected(("y", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge), ("z", CaseEdge)) succOf("y") shouldBe expected(("break;", AlwaysEdge)) succOf("break;") shouldBe expected(("RET", AlwaysEdge)) - succOf("default:") shouldBe expected(("z", AlwaysEdge)) succOf("z") shouldBe expected(("RET", AlwaysEdge)) } "be correct for nested switch" in - new CfgFixture("switch (x) { case 1: switch(y) { default: z; } }") { + new CfgFixture("switch (x) { default: switch(y) { default: z; } }") { succOf("func ()") shouldBe expected(("x", AlwaysEdge)) - succOf("x") shouldBe expected(("case 1:", CaseEdge), ("RET", CaseEdge)) - succOf("case 1:") shouldBe expected(("y", AlwaysEdge)) - succOf("y") shouldBe expected(("default:", CaseEdge)) - succOf("default:") shouldBe expected(("z", AlwaysEdge)) + succOf("x") shouldBe expected(("y", CaseEdge)) + succOf("y") shouldBe expected(("z", CaseEdge)) succOf("z") shouldBe expected(("RET", AlwaysEdge)) } } @@ -469,7 +436,7 @@ class CfgFixture(file1Code: String) { def expected(pairs: (String, CfgEdgeType)*): Set[String] = { pairs.map { - case (code, _) => + case (code, cfgEdgeType) => codeToNode(code).start.code.head }.toSet } diff --git a/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/TypeNodePassTests.scala b/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/TypeNodePassTests.scala deleted file mode 100644 index 7efe57e..0000000 --- a/src/test/scala/io/shiftleft/fuzzyc2cpg/passes/TypeNodePassTests.scala +++ /dev/null @@ -1,34 +0,0 @@ -package io.shiftleft.fuzzyc2cpg.passes - -import better.files.File -import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.passes.IntervalKeyPool -import org.scalatest.{Matchers, WordSpec} -import io.shiftleft.semanticcpg.language._ -import scala.jdk.CollectionConverters._ - -class TypeNodePassTests extends WordSpec with Matchers { - "TypeNodePass" should { - "create TYPE nodes for used types" in TypeNodePassFixture("int main() { int x; }") { cpg => - cpg.typ.name.toSet shouldBe Set("int", "void") 
- } - } -} - -object TypeNodePassFixture { - def apply(file1Code: String)(f: Cpg => Unit): Unit = { - File.usingTemporaryDirectory("fuzzyctest") { dir => - val file1 = (dir / "file1.c") - file1.write(file1Code) - - val cpg = Cpg.emptyCpg - val keyPool = new IntervalKeyPool(1001, 2000) - val filenames = List(file1.path.toAbsolutePath.toString) - val astCreator = new AstCreationPass(filenames, cpg, keyPool) - astCreator.createAndApply() - new TypeNodePass(astCreator.global.usedTypes.keys().asScala.toList, cpg).createAndApply() - - f(cpg) - } - } -}