diff --git a/src/main/scala/PassimApp.scala b/src/main/scala/PassimApp.scala
index 599ddd9..05a9ae8 100644
--- a/src/main/scala/PassimApp.scala
+++ b/src/main/scala/PassimApp.scala
@@ -85,6 +85,8 @@ case class PassAlign(id1: String, id2: String,
 
 case class AlignedStrings(s1: String, s2: String, matches: Int, score: Float)
 
+case class AlignedStringsWithOffsets(s1: String, s2: String, b1: Int, b2: Int, e1: Int, e2: Int, matches: Int, score: Float)
+
 case class LinkedSpan(span: Span, links: ArrayBuffer[Long])
 
 case class ExtentPair(seq1: Int, seq2: Int, begin1: Int, begin2: Int, end1: Int, end2: Int, tok1: Int, tok2: Int)
@@ -281,6 +283,91 @@ object PassFun {
       }).toSeq
     }
   }
+
+  def recursivelySWAlignStrings(s1: String, s2: String,
+    n: Int, gap2: Int, matchMatrix: jaligner.matrix.Matrix,
+    openGap: Float, contGap: Float): Seq[AlignedPassage] = {
+    val m1 = hapaxIndex(n, s1)
+    val m2 = hapaxIndex(n, s2)
+    val inc = increasingMatches(m1
+      .flatMap(z => if (m2.contains(z._1)) Some((z._2, m2(z._1), 1)) else None))
+    val prod = s1.size * s2.size
+    if ( inc.size == 0 && (prod >= gap2 || prod < 0) ) {
+      Seq()
+    } else {
+      (Array((0, 0, 0)) ++ inc ++ Array((s1.size, s2.size, 0)))
+        .sliding(2).flatMap(z => {
+          val (b1, b2, c) = z(0)
+          val (e1, e2, d) = z(1)
+          val n1 = e1 - b1
+          val n2 = e2 - b2
+          val chartSize = n1 * n2
+          if ( c == 0 ) { // if we're before the first matching n-gram
+            if ( d == 0 ) {
+              Seq()
+            } else {
+              val p1 = s1.substring(Math.max(0, e1 - 100), e1 + n)
+              val p2 = s2.substring(Math.max(0, e2 - 100), e2 + n)
+              val alg = jaligner.SmithWatermanGotoh.align(new jaligner.Sequence(p1),
+                new jaligner.Sequence(p2), matchMatrix, openGap, contGap)
+              val a1 = new String(alg.getSequence1())
+              val a2 = new String(alg.getSequence2())
+              val len1 = a1.length - a1.count(_ == '-')
+              val len2 = a2.length - a2.count(_ == '-')
+              // Check that the SW alignment is anchored at the right edge.
+              if ( len1 > n && len2 > n
+                && alg.getStart1() + len1 >= p1.length && alg.getStart2() + len2 >= p2.length ) {
+                Seq(AlignedPassage(a1.substring(0, a1.length - n),
+                  a2.substring(0, a2.length - n),
+                  e1 - len1 + n, e2 - len2 + n, alg.getIdentity - n, alg.getScore))
+              } else {
+                Seq()
+              }
+            }
+          } else if ( d == 0 ) { // if we're after the last matching n-gram
+            val p1 = s1.substring(b1, Math.min(b1 + 100, e1))
+            val p2 = s2.substring(b2, Math.min(b2 + 100, e2))
+            val alg = jaligner.SmithWatermanGotoh.align(new jaligner.Sequence(p1),
+              new jaligner.Sequence(p2), matchMatrix, openGap, contGap)
+            val a1 = new String(alg.getSequence1())
+            val a2 = new String(alg.getSequence2())
+            // Check that the SW alignment is anchored at the left edge.
+            if ( alg.getStart1() == 0 && alg.getStart2() == 0 ) {
+              Seq(AlignedPassage(a1, a2, b1, b2, alg.getIdentity, alg.getScore))
+            } else {
+              Seq()
+            }
+          } else if ( chartSize <= gap2 && chartSize >= 0 ) { // chart fits; >= 0 guards against int overflow
+            val p1 = s1.substring(b1, e1)
+            val p2 = s2.substring(b2, e2)
+            if ( n1 == n2 && p1 == p2 ) {
+              Seq(AlignedPassage(p1, p2, b1, b2, p1.size, 2.0f * p2.size))
+            } else {
+              val alg = jaligner.NeedlemanWunschGotoh.align(new jaligner.Sequence(p1),
+                new jaligner.Sequence(p2), matchMatrix, openGap, contGap)
+              // HACK: jaligner sometimes returns the two aligned sequences in swapped order.
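+              // The swap does not appear to be documented in jaligner, so rather than trust
+              // the argument order, the branch below strips the gap characters from both
+              // outputs and checks which input string each output actually corresponds to.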
+              val a1 = new String(alg.getSequence2)
+              val a2 = new String(alg.getSequence1)
+              if ( a1.replaceAll("-", "") == p2 && a2.replaceAll("-", "") == p1 ) {
+                Seq(AlignedPassage(a2, a1, b1, b2, alg.getIdentity, alg.getScore))
+              } else {
+                Seq(AlignedPassage(a1, a2, b1, b2, alg.getIdentity, alg.getScore))
+              }
+            }
+          } else {
+            if ( c > 0 ) {
+              val len = Math.min(n, Math.min(n1, n2))
+              val p1 = s1.substring(b1, b1 + len)
+              val p2 = s2.substring(b2, b2 + len)
+              Array(AlignedPassage(p1, p2, b1, b2, len, 2.0f * len)) ++
+                recursivelyAlignStrings(s1.substring(b1 + len, e1), s2.substring(b2 + len, e2), n, gap2, matchMatrix, openGap, contGap)
+            } else {
+              recursivelyAlignStrings(s1.substring(b1, e1), s2.substring(b2, e2), n, gap2, matchMatrix, openGap, contGap)
+            }
+          }
+        }).toSeq
+    }
+  }
 }
 
 case class TokText(terms: Array[String], termCharBegin: Array[Int], termCharEnd: Array[Int])
 
@@ -560,8 +647,44 @@ transform($pageCol,
       (if ( s2 != null ) s2.replaceAll("-", "\u2010") else ""),
       config.n, config.gap * config.gap, matchMatrix, openGap, contGap)
-    AlignedStrings(chunks.map(_.s1).mkString, chunks.map(_.s2).mkString,
+
+    AlignedStrings(chunks.map(_.s1).mkString, chunks.map(_.s2).mkString,
       chunks.map(_.matches).sum, chunks.map(_.score).sum)
     }
   }
+
+  def makeSWAligner(config: Config,
+    matchScore: Float = 2, mismatchScore: Float = -1,
+    openGap: Float = 5.0f, contGap: Float = 0.5f) = {
+    val matchMatrix = jaligner.matrix.MatrixGenerator.generate(matchScore, mismatchScore)
+    udf { (s1: String, s2: String) =>
+      //multiply n by the average word length so that longer character n-grams serve as
+      // anchor points for the alignment
+      val chunks = PassFun.recursivelySWAlignStrings(
+        (if ( s1 != null ) s1.replaceAll("-", "\u2010") else ""),
+        (if ( s2 != null ) s2.replaceAll("-", "\u2010") else ""),
+        (config.n * config.wordLength).toInt, config.gap * config.gap,
+        matchMatrix, openGap, contGap)
+
+      //compute the begin and end points of the alignment (in characters).
+      // Sometimes the alignment returns an empty sequence, so we need to check that
+      // hasn't happened before trying to calculate the endpoints.
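+      // For example (illustrative numbers only): if the first chunk begins at b1 = 40 and
+      // the concatenated aligned string for s1 is 120 characters long, 5 of which are gap
+      // characters, then e1 = 40 + 115, since gaps consume no characters of the original text.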
+      if (chunks.length > 0) {
+        val firstChunk = chunks(0)
+        val b1 = firstChunk.b1
+        val b2 = firstChunk.b2
+        val e1 = b1 + chunks.map(_.s1).mkString.replace("-","").length()
+        val e2 = b2 + chunks.map(_.s2).mkString.replace("-","").length()
+        AlignedStringsWithOffsets(chunks.map(_.s1).mkString, chunks.map(_.s2).mkString, b1, b2, e1, e2,
+          chunks.map(_.matches).sum, chunks.map(_.score).sum)
+      } else {
+        //no alignment was found: return an empty result with zeroed offsets
+        AlignedStringsWithOffsets("", "", 0, 0, 0, 0, 0, 0.0f)
+      }
+    }
+  }
 
@@ -811,7 +934,6 @@ transform($pageCol,
   }
 
   def aggregateAlignments(config: Config, corpus: DataFrame, extents: DataFrame): DataFrame = {
     import align.sparkSession.implicits._
-    val alignStrings = makeStringAligner(config)
     val neededCols = corpus.select("uid", "seq", config.id, config.group)
     var texts = corpus.select(config.group, "seq","text").withColumnRenamed(config.group,"t_series").withColumnRenamed("seq","t_seq")
@@ -880,7 +1002,7 @@ transform($pageCol,
     //we will now deduplicate the aggregated docuemnts
     val dedupedlicatedDocs = allDocs.groupBy("seriesStr","seqsList")
       .agg(first("seriesStr") as "series",first("seqsList") as "seqs",
-        collect_list("pairID") as "pairIDs")
+           collect_list("pairID") as "pairIDs")
       .select("series","seqs","pairIDs")
       .withColumn("docID",makeId(col("series"),col("seqs")))
@@ -891,7 +1013,7 @@ transform($pageCol,
       .drop("t_seq","t_series")
       .groupBy("series","docID")
       .agg(collect_list("text") as "text",
-        collect_list("seq") as "seqs")
+           collect_list("seq") as "seqs")
       .select("docID","seqs","text")
 
     //we will finally add the texts to the dataframe of aggregate documents
@@ -920,9 +1042,6 @@ transform($pageCol,
     var castSpans = spans.map(PassimApp.rowToExtentPair)
 
     //now we will aggregate the seq pairs into spans of adjacent extents
-    //
-    // outermost list
-    // one entry per group of consecutive chunks between the pair of texts
     var aggregatedPairs = Array[Array[Array[Int]]]()
     var usedInAggregate = Array[Int]()
     //for each pair in the set of extent pairs...
@@ -963,13 +1082,11 @@ transform($pageCol,
          var lengthFromBegin1 = nextSpan.begin1
          var lengthFromBegin2 = nextSpan.begin2
 
-          //if the next pair is not adjacent to the current one in both books, we're done
-          // we're also done if the current span pair ends more than tokens from
-          // the the boundary in both texts or the next one starts more than tokens from the boundary in both texts
-          if ((nextPos1 >= pos1+1 && nextPos2 > pos2 + 1) || (nextPos2 < pos2) ||
-              (((lengthToEnd2>gap) && (lengthToEnd1>gap)) ||
-               ((lengthFromBegin1>gap) && (lengthFromBegin2>gap)))) {
+          // the second clause of the old test might not actually be needed; it might be
+          // actively causing problems. Investigate if we notice odd choices of what to
+          // combine or not combine when aggregating at the chunk level.
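+          // Illustration with hypothetical (seq1, seq2) pairs: (3,10) -> (4,11) stays in one
+          // aggregate span, since neither sequence number jumps by more than one; (4,11) -> (7,14)
+          // ends the span because both sequence numbers jump; (4,11) -> (6,9) ends it because
+          // the second text moves backwards.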
+          if ((nextPos1 > pos1+1 && nextPos2 > pos2 + 1) || (nextPos2 < pos2)) {
+            finished = true
+            //check if we should look in the next chunk after this one in either text
+            if (lengthToEnd1
+              ((overlap(0).length > 0) && (currentPair(1).last == otherPair(1)(0)-1)) ||
+              //abut in text 1
+              ((overlap(1).length > 0) && (currentPair(0).last == otherPair(0)(0)-1)) ||
+              //overlap in both
+              ((overlap(0).length > 0) && (overlap(1).length > 0))) {
+              //if there is overlap and the missing spans are later (or earlier) in both texts,
+              // merge the two spans by adding the missing seq numbers to the span we're merging into
+              var missingFrom1 = otherPair(0).diff(currentPair(0)).sorted
+              var missingFrom2 = otherPair(1).diff(currentPair(1)).sorted
+
+              if ((missingFrom1.length == 0) || (missingFrom2.length == 0) ||
+                //both missing sequences are from before the aligned sections
+                ((missingFrom1.last < currentPair(0)(0)) && (missingFrom2.last < currentPair(1)(0))) ||
+                //both missing sequences are from after
+                ((missingFrom1(0) > currentPair(0).last) && (missingFrom2(0) > currentPair(1).last))) {
+
+                currentPair(0) = (currentPair(0) ++ missingFrom1).sorted
+                currentPair(1) = (currentPair(1) ++ missingFrom2).sorted
+                usedInMerge :+= j
+              }
+
+              j += 1
+            } else {
+              j += 1
+            }
+          }
-          aggregatedPairs :+= aggregate
+          aggregatedPairsMerged :+= currentPair
         }
+
       }
-    aggregatedPairs
+
+    aggregatedPairsMerged
   }
 
   val makeId = udf { (series: String,seqs: Seq[Int]) => (series+"_"+seqs(0).toString+"-"+seqs.last.toString) }
   val hashId = udf { (id: String) => hashString(id) }
@@ -1046,6 +1214,53 @@ transform($pageCol,
     val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
     fs.exists(qualified)
   }
+  //calculates the offset for each seq in a series to determine the series-level starting point
+  // of each seq
+  val calculateOffsets = udf { (docs: Seq[Row]) =>
+    var offsets = Array(0)
+    for (index <- docs.indices) {
+      if (index > 0) {
+        offsets = offsets :+ (offsets.last + (docs(index-1)(2).asInstanceOf[Int]))
+      }
+    }
+    offsets
+  }
+  //does the same, but at the token level
+  val calculateTokenOffsets = udf { (docs: Seq[Row]) =>
+    var offsets = Array(0)
+    for (index <- docs.indices) {
+      if (index > 0) {
+        offsets = offsets :+ (offsets.last + (docs(index-1)(3).asInstanceOf[Int]))
+      }
+    }
+    offsets
+  }
+  //collects all the texts for a given series. This could probably be done more simply with map.
+  val concatStrings = udf { (texts: Seq[Row]) =>
+    var allTexts = ""
+    for (index <- texts.indices) {
+      allTexts = allTexts.concat(texts(index)(1).asInstanceOf[String])
+    }
+    allTexts
+  }
+  //extracts the ordered seq numbers for the documents in a series
+  val getSeqs = udf { (docs: Seq[Row]) =>
+    var seqs = Array[Int]()
+    for (index <- docs.indices) {
+      seqs = seqs :+ (docs(index)(0).asInstanceOf[Int])
+    }
+    seqs
+  }
+
+  //given an aligned sequence and an end index, this function counts the number of spaces prior
+  // to that end point, ignoring gaps in the alignment.
+  // This is used to get a quick estimate of the start token index of an alignment.
+  // It might be wrong, but it will be wrong consistently, so it should work out alright.
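+  // For example (illustrative values only): for the aligned text "the qu-ick fox" and
+  // endpoint 10, stripping the gap gives "the quick fox"; its first 10 characters,
+  // "the quick ", contain 2 spaces, so the alignment point lands near token index 2.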
+  //It's better than trying to reverse-engineer the tokenization.
+  val countSpaces = udf { (text: String, endpoint: Int) =>
+    text.replace("-","").slice(0, endpoint).count(_ == ' ')
+  }
 
   def main(args: Array[String]) {
     val conf = new SparkConf()
@@ -1270,18 +1485,93 @@ transform($pageCol,
         .save(config.outputPath + "/context." + config.outputFormat)
     }
 
-    if (config.pairwise || config.aggregate) {
+    if ( config.pairwise ) {
       val alignments = extents.pairwiseAlignments(config, corpus)
-      if ( config.pairwise ) {
-        alignments.write.mode("ignore").format(config.outputFormat)
-          .save(config.outputPath + "/align." + config.outputFormat)
-      }
+      alignments.write.format(config.outputFormat)
+        .save(config.outputPath + "/align." + config.outputFormat)
+    }
 
-      if ( config.aggregate ) {
-        extents.aggregateAlignments(config, corpus, extents)
-          .write.mode("ignore").format(config.outputFormat)
-          .save(config.outputPath + "/aggregate." + config.outputFormat)
-      }
+    if ( config.aggregate ) {
+      extents.aggregateAlignments(config, corpus, extents)
+        .write.format(config.outputFormat)
+        .save(config.outputPath + "/aggregate." + config.outputFormat)
+
+      val aggregate = spark.read.format(config.outputFormat).load(config.outputPath + "/aggregate." + config.outputFormat)
+        .withColumn("pairID", explode('pairIDs)).drop("pairIDs")
+
+      //collect the documents for a given series, so we can compute the series-level offsets
+      // of each document. Also fix some type weirdness: seq should be an Int, not a Long.
+      val seriesLevelCorpus = corpus.withColumn("len", length(col("text")))
+        .withColumn("seq", col("seq").cast(IntegerType))
+        .withColumn("tok", size(col("terms")).cast(IntegerType))
+        .groupBy("series").agg(sort_array(collect_list(struct("seq", "text", "len", "tok"))).alias("orderedTexts"))
+        .withColumn("seqOffsets", calculateOffsets(col("orderedTexts")))
+        .withColumn("tokOffsets", calculateTokenOffsets(col("orderedTexts")))
+        .withColumn("seqs", getSeqs(col("orderedTexts")))
+        .withColumn("text", concatStrings(col("orderedTexts")))
+        .drop("orderedTexts")
+
+      //corpus.withColumn("tok",size(col("terms")).cast(IntegerType)).select("id","series","tok","seq","text").orderBy("series","seq").write.json(config.outputPath+"/corpus.json")
+      //seriesLevelCorpus.write.json(config.outputPath+"/aggSeries.json")
+
+      //create a dataframe of (series, seq, offset) triples so the endpoints of alignments
+      // can be properly set.
+      //the column renaming simplifies the joins, since it makes the column names unambiguous;
+      // there's probably a better way to do this, but I do not know it
+      val seqOffsets = seriesLevelCorpus.select("series","seqs","seqOffsets","tokOffsets")
+        .withColumn("offsetsWithSeq", arrays_zip(col("seqs"), col("seqOffsets"), col("tokOffsets")))
+        .drop("seqs","seqOffsets","tokOffsets")
+        .withColumn("offsetWithSeq", explode(col("offsetsWithSeq")))
+        .withColumn("offset", col("offsetWithSeq.seqOffsets"))
+        .withColumn("seq", col("offsetWithSeq.seqs"))
+        .withColumn("tok", col("offsetWithSeq.tokOffsets"))
+        .select("series","seq","offset","tok")
+        .withColumnRenamed("series","_series").withColumnRenamed("seq","_seq").withColumnRenamed("tok","_tok")
+
+      val fields = aggregate.columns.filter { _ != "pairID" }.map(expr)
+
+      val swAligner = makeSWAligner(config)
+      aggregate.select('pairID, struct(fields:_*) as "info1")
+        .join(aggregate.select('pairID, struct(fields:_*) as "info2"), "pairID")
+        .filter($"info1.id" < $"info2.id")
+        .withColumn("swalg", swAligner($"info1.text", $"info2.text"))
+        //flatten the nested struct fields into unambiguous top-level columns
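+        // (Spark can also expand a struct wholesale, e.g. .select($"swalg.*"), though that
+        // gives less control over the resulting column aliases.)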
+        .select(col("info1.text").alias("text1"), col("info2.text").alias("text2"),
+          col("info1.id").alias("id1"), col("info1.seqs").alias("seqs1"), col("info1.series").alias("series1"),
+          col("info2.id").alias("id2"), col("info2.seqs").alias("seqs2"), col("info2.series").alias("series2"),
+          col("swalg.b1").alias("b1"), col("swalg.b2").alias("b2"), col("swalg.e1").alias("e1"),
+          col("swalg.e2").alias("e2"), col("swalg.matches").alias("matches"),
+          col("swalg.s1").alias("s1"), col("swalg.s2").alias("s2"), col("swalg.score").alias("score"))
+        //add the series-level character offsets and set the id to that of
+        // the book the text is taken from
+        .withColumn("beginSeq1", element_at(col("seqs1"), 1))
+        .withColumn("beginSeq2", element_at(col("seqs2"), 1))
+        //add the offsets to the start points to get the offsets into the entire book
+        .join(seqOffsets, (col("series1") === seqOffsets("_series")) && (col("beginSeq1") === seqOffsets("_seq")))
+        .withColumnRenamed("offset","offset1").withColumnRenamed("_tok","tokOffset1").drop("_series","_seq")
+        .join(seqOffsets, (col("series2") === seqOffsets("_series")) && (col("beginSeq2") === seqOffsets("_seq")))
+        .withColumnRenamed("offset","offset2").withColumnRenamed("_tok","tokOffset2").drop("_series","_seq")
+        .withColumn("bw1", countSpaces(col("text1"), col("b1")) + col("tokOffset1"))
+        .withColumn("bw2", countSpaces(col("text2"), col("b2")) + col("tokOffset2"))
+        .withColumn("ew1", countSpaces(col("text1"), col("e1")) + col("tokOffset1"))
+        .withColumn("ew2", countSpaces(col("text2"), col("e2")) + col("tokOffset2"))
+        .withColumn("len1", length(col("text1")))
+        .withColumn("len2", length(col("text2")))
+        .withColumn("bDoc1", col("b1")).withColumn("bDoc2", col("b2"))
+        .withColumn("eDoc1", col("e1")).withColumn("eDoc2", col("e2"))
+        .withColumn("b1", col("b1") + col("offset1")).withColumn("b2", col("b2") + col("offset2"))
+        .withColumn("e1", col("e1") + col("offset1")).withColumn("e2", col("e2") + col("offset2"))
+        .drop("text1","text2")
+        //update the ids to just be the book ids
+        .withColumn("idDoc1", col("id1"))
+        .withColumn("idDoc2", col("id2"))
+        .withColumn("id1", col("series1"))
+        .withColumn("id2", col("series2"))
+        //remove unneeded fields
+        .drop("beginSeq1","beginSeq2","offset1","offset2")
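+        // At this point b1/e1 and b2/e2 are series-level character offsets (document-level
+        // offsets shifted by the start of each document within its concatenated series),
+        // while bDoc1/eDoc1 and bDoc2/eDoc2 retain the original document-level offsets.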
+        .write.format(config.outputFormat)
+        .save(config.outputPath + "/aggregateAlignments." + config.outputFormat)
+
+      sys.exit(0)
     }
 
     if ( config.boilerplate || config.docwise || config.linewise) {
@@ -1371,53 +1661,5 @@ transform($pageCol,
       .write.mode("ignore").format(config.outputFormat).save(outFname)
 
     spark.stop()
-
-    //if the aggregate option was passed in, we must now call main on the aggregated documents
-    // inheriting the config from the current run
-    if ( config.aggregate ) {
-      var command = ""
-
-      //tildes are used to divide arguments, as there's a comma in the --filterpairs argument
-      command = command.concat("-n~".concat(config.n.toString))
-      command = command.concat("~-l~".concat(config.minDF.toString))
-      command = command.concat("~-u~".concat(config.maxDF.toString))
-      command = command.concat("~-m~".concat(config.minRep.toString))
-      command = command.concat("~-a~".concat(config.minAlg.toString))
-      command = command.concat("~-c~".concat(config.context.toString))
-      command = command.concat("~-o~".concat(config.relOver.toString))
-      command = command.concat("~-M~".concat(config.mergeDiverge.toString))
-      command = command.concat("~-r~".concat(config.maxRep.toString))
-      command = command.concat("~-g~".concat(config.gap.toString))
-      command = command.concat("~-i~".concat("id"))
-      command = command.concat("~-t~".concat("text"))
-      command = command.concat("~-s~".concat("series"))
-
-      if (config.pairwise) {
-        command = command.concat("~--pairwise")
-      }
-
-      if (config.docwise) {
-        command = command.concat("~--docwise")
-      }
-
-      if (config.fields != "") {
-        command = command.concat("~--fields~".concat(config.fields.concat(";pairIDs")))
-      } else {
-        command = command.concat("~--fields~".concat(config.fields.concat("pairIDs")))
-      }
-
-      //figure out how to restrict our search to doc pairs that share a pairID
-      command = command.concat("~--filterpairs~".concat(config.filterpairs.concat(" AND (size(array_intersect(pairIDs,pairIDs2))>0)")))
-
-      command = command.concat("~--input-format~".concat(config.outputFormat))
-      command = command.concat("~--output-format~".concat(config.outputFormat))
-      command = command.concat("~-w~".concat(config.wordLength.toString))
-
-      //add the input and output paths
-      command = command.concat("~"+config.outputPath+"/aggregate." + config.outputFormat)
-      command = command.concat("~"+config.outputPath+"/aggregateAlignments")
-
-      main(command.split("~"))
-    }
   }
 }