Skip to content

Commit

Permalink
Implement CheckResizeSegmentsRequest for UnicodeRewriter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712807064
  • Loading branch information
hiroyuki-komatsu committed Jan 7, 2025
1 parent 17e5145 commit 91729ed
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 78 deletions.
15 changes: 15 additions & 0 deletions src/converter/converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2354,4 +2354,19 @@ TEST_F(ConverterTest, IntegrationWithSymbolRewriter) {
}
}

// Checks that a "U+XXXX" style key is converted into the character it denotes
// when the conversion runs through the converter obtained from the mock data
// engine (i.e. with the rewriters wired in, not UnicodeRewriter in isolation).
TEST_F(ConverterTest, IntegrationWithUnicodeRewriter) {
  std::unique_ptr<EngineInterface> engine =
      MockDataEngineFactory::Create().value();
  ConverterInterface *converter = engine->GetConverter();

  {
    Segments segments;
    const ConversionRequest convreq =
        ConversionRequestBuilder().SetKey("U+3042").Build();
    ASSERT_TRUE(converter->StartConversion(convreq, &segments));
    // The whole key must end up in a single conversion segment.
    EXPECT_EQ(segments.conversion_segments_size(), 1);
    // U+3042 is HIRAGANA LETTER A; the candidate list must contain that
    // character rather than an empty value.
    EXPECT_TRUE(FindCandidateByValue("あ", segments.conversion_segment(0)));
  }
}

} // namespace mozc
1 change: 1 addition & 0 deletions src/rewriter/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,7 @@ mozc_cc_test(
timeout = "moderate",
srcs = ["unicode_rewriter_test.cc"],
deps = [
":rewriter_interface",
":unicode_rewriter",
"//composer",
"//converter:segments",
Expand Down
2 changes: 1 addition & 1 deletion src/rewriter/rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ Rewriter::Rewriter(const engine::Modules &modules,
AddRewriter(EmoticonRewriter::CreateFromDataManager(*data_manager));
AddRewriter(std::make_unique<CalculatorRewriter>());
AddRewriter(std::make_unique<SymbolRewriter>(data_manager));
AddRewriter(std::make_unique<UnicodeRewriter>(&parent_converter));
AddRewriter(std::make_unique<UnicodeRewriter>());
AddRewriter(std::make_unique<VariantsRewriter>(pos_matcher));
AddRewriter(std::make_unique<ZipcodeRewriter>(pos_matcher));
AddRewriter(std::make_unique<DiceRewriter>());
Expand Down
78 changes: 52 additions & 26 deletions src/rewriter/unicode_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <utility>

Expand All @@ -47,6 +49,7 @@
#include "converter/converter_interface.h"
#include "converter/segments.h"
#include "request/conversion_request.h"
#include "rewriter/rewriter_interface.h"

namespace mozc {
namespace {
Expand Down Expand Up @@ -122,50 +125,73 @@ bool UnicodeRewriter::RewriteToUnicodeCharFormat(
return true;
}

namespace {
// Converts a code point expression such as "U+0041" into the UTF-8 string of
// the character it denotes (e.g. "A").
//
// Returns std::nullopt when `key` is rejected by IsValidCodepointExpression,
// cannot be parsed by UCS4ExpressionToInteger, names a code point that
// Util::IsAcceptableCharacterAsCandidate refuses, or encodes to an empty
// UTF-8 string.
std::optional<std::string> GetValue(absl::string_view key) {
  if (!IsValidCodepointExpression(key)) {
    return std::nullopt;
  }

  uint32_t codepoint = 0;
  if (!UCS4ExpressionToInteger(key, &codepoint)) {
    return std::nullopt;
  }

  if (!Util::IsAcceptableCharacterAsCandidate(codepoint)) {
    return std::nullopt;
  }

  // Deliberately non-const so the string is moved, not copied, into the
  // returned optional.
  std::string value = Util::CodepointToUtf8(codepoint);
  if (value.empty()) {
    return std::nullopt;
  }

  return value;
}
}  // namespace

// Requests that the conversion segments be merged into a single segment when
// the request key as a whole is a convertible code point expression
// (ex. "U+0" + "02" + "0" -> "U+0020").
//
// Returns std::nullopt (no resize wanted) when:
//  - the segments were already resized, or there is at most one conversion
//    segment, so there is nothing to merge;
//  - the key is longer than what fits in a uint8_t segment size;
//  - the key does not map to a character (see GetValue).
// Otherwise returns a request whose first segment, starting at index 0, spans
// the whole key; the remaining sizes are zero.
std::optional<RewriterInterface::ResizeSegmentsRequest>
UnicodeRewriter::CheckResizeSegmentsRequest(const ConversionRequest &request,
                                            const Segments &segments) const {
  if (segments.resized() || segments.conversion_segments_size() <= 1) {
    // The given segments are already resized, or there is nothing to merge.
    return std::nullopt;
  }

  absl::string_view key = request.key();
  const size_t key_len = Util::CharsLen(key);
  if (key_len > std::numeric_limits<uint8_t>::max()) {
    // The length would be truncated by the uint8_t cast below.
    return std::nullopt;
  }
  const uint8_t segment_size = static_cast<uint8_t>(key_len);

  // Only the convertibility of the key matters here; the character itself is
  // not needed to build the resize request.
  if (!GetValue(key).has_value()) {
    return std::nullopt;
  }

  ResizeSegmentsRequest resize_request = {
      .segment_index = 0,
      .segment_sizes = {segment_size, 0, 0, 0, 0, 0, 0, 0},
  };
  return resize_request;
}

// If the key is in the "U+xxxx" format, the corresponding Unicode
// character is added. (ex. "U+0041" -> "A").
//
// Only acts when there is exactly one conversion segment; this function does
// not merge segments itself. Returns true iff a candidate was added to that
// segment.
bool UnicodeRewriter::RewriteFromUnicodeCharFormat(
    const ConversionRequest &request, Segments *segments) const {
  if (segments->conversion_segments_size() != 1) {
    return false;
  }

  absl::string_view key = request.key();
  std::optional<std::string> value = GetValue(key);
  if (!value.has_value()) {
    return false;
  }
  DCHECK_EQ(1, segments->conversion_segments_size());

  Segment *segment = segments->mutable_conversion_segment(0);
  // `key` is only a view, so it is copied into an owned string; the converted
  // character is moved out of the optional.
  AddCandidate(std::string(key), std::move(value.value()), 0, segment);
  return true;
}

Expand Down
11 changes: 4 additions & 7 deletions src/rewriter/unicode_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
#ifndef MOZC_REWRITER_UNICODE_REWRITER_H_
#define MOZC_REWRITER_UNICODE_REWRITER_H_

#include "absl/log/check.h"
#include "converter/converter_interface.h"
#include <optional>
#include "converter/segments.h"
#include "request/conversion_request.h"
#include "rewriter/rewriter_interface.h"
Expand All @@ -40,10 +39,9 @@ namespace mozc {

class UnicodeRewriter : public RewriterInterface {
public:
explicit UnicodeRewriter(const ConverterInterface *parent_converter)
: parent_converter_(parent_converter) {
DCHECK(parent_converter_);
}
std::optional<RewriterInterface::ResizeSegmentsRequest>
CheckResizeSegmentsRequest(const ConversionRequest &request,
const Segments &segments) const override;

bool Rewrite(const ConversionRequest &request,
Segments *segments) const override;
Expand All @@ -53,7 +51,6 @@ class UnicodeRewriter : public RewriterInterface {
Segments *segments) const;
bool RewriteFromUnicodeCharFormat(const ConversionRequest &request,
Segments *segments) const;
const ConverterInterface *parent_converter_;
};

} // namespace mozc
Expand Down
160 changes: 116 additions & 44 deletions src/rewriter/unicode_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <cstdlib>
#include <iterator>
#include <memory>
#include <optional>
#include <string>

#include "absl/strings/str_format.h"
Expand All @@ -45,6 +46,7 @@
#include "protocol/commands.pb.h"
#include "protocol/config.pb.h"
#include "request/conversion_request.h"
#include "rewriter/rewriter_interface.h"
#include "testing/gunit.h"
#include "testing/mozctest.h"

Expand Down Expand Up @@ -95,8 +97,7 @@ class UnicodeRewriterTest : public testing::TestWithTempUserProfile {

TEST_F(UnicodeRewriterTest, UnicodeConversionTest) {
Segments segments;
UnicodeRewriter rewriter(engine_->GetConverter());
const ConversionRequest request;
UnicodeRewriter rewriter;

struct UCS4UTF8Data {
absl::string_view codepoint;
Expand Down Expand Up @@ -163,6 +164,8 @@ TEST_F(UnicodeRewriterTest, UnicodeConversionTest) {
for (uint32_t ascii = 0x20; ascii < 0x7F; ++ascii) {
const std::string codepoint = absl::StrFormat("U+00%02X", ascii);
InitSegments(codepoint, codepoint, &segments);
const ConversionRequest request =
ConversionRequestBuilder().SetKey(codepoint).Build();
EXPECT_TRUE(rewriter.Rewrite(request, &segments));
EXPECT_EQ(segments.segment(0).candidate(0).value.at(0), ascii);
EXPECT_TRUE(segments.segment(0).candidate(0).attributes &
Expand All @@ -171,8 +174,10 @@ TEST_F(UnicodeRewriterTest, UnicodeConversionTest) {

// Mozc accepts Japanese characters
for (size_t i = 0; i < std::size(kCodepointUtf8Data); ++i) {
InitSegments(kCodepointUtf8Data[i].codepoint,
kCodepointUtf8Data[i].codepoint, &segments);
absl::string_view codepoint = kCodepointUtf8Data[i].codepoint;
InitSegments(codepoint, codepoint, &segments);
const ConversionRequest request =
ConversionRequestBuilder().SetKey(codepoint).Build();
EXPECT_TRUE(rewriter.Rewrite(request, &segments));
EXPECT_TRUE(ContainCandidate(segments, kCodepointUtf8Data[i].utf8));
EXPECT_TRUE(segments.segment(0).candidate(0).attributes &
Expand All @@ -182,58 +187,125 @@ TEST_F(UnicodeRewriterTest, UnicodeConversionTest) {
// Mozc does not accept other characters
for (size_t i = 0; i < std::size(kMozcUnsupportedUtf8); ++i) {
InitSegments(kMozcUnsupportedUtf8[i], kMozcUnsupportedUtf8[i], &segments);
const ConversionRequest request =
ConversionRequestBuilder().SetKey(kMozcUnsupportedUtf8[i]).Build();
EXPECT_FALSE(rewriter.Rewrite(request, &segments));
}

// Invalid style input
InitSegments("U+123456789ABCDEF0", "U+123456789ABCDEF0", &segments);
EXPECT_FALSE(rewriter.Rewrite(request, &segments));
absl::string_view invalid_key1 = "U+123456789ABCDEF0";
InitSegments(invalid_key1, invalid_key1, &segments);
const ConversionRequest request1 =
ConversionRequestBuilder().SetKey(invalid_key1).Build();
EXPECT_FALSE(rewriter.Rewrite(request1, &segments));

absl::string_view invalid_key2 = "U+12345678";
InitSegments(invalid_key2, invalid_key2, &segments);
const ConversionRequest request2 =
ConversionRequestBuilder().SetKey(invalid_key2).Build();
EXPECT_FALSE(rewriter.Rewrite(request2, &segments));

absl::string_view invalid_key3 = "U+XYZ";
InitSegments(invalid_key3, invalid_key3, &segments);
const ConversionRequest request3 =
ConversionRequestBuilder().SetKey(invalid_key3).Build();
EXPECT_FALSE(rewriter.Rewrite(request3, &segments));

absl::string_view invalid_key4 = "12345";
InitSegments(invalid_key4, invalid_key4, &segments);
const ConversionRequest request4 =
ConversionRequestBuilder().SetKey(invalid_key4).Build();
EXPECT_FALSE(rewriter.Rewrite(request4, &segments));

absl::string_view invalid_key5 = "U12345";
InitSegments(invalid_key5, invalid_key5, &segments);
const ConversionRequest request5 =
ConversionRequestBuilder().SetKey(invalid_key5).Build();
EXPECT_FALSE(rewriter.Rewrite(request5, &segments));
}

InitSegments("U+1234567", "U+12345678", &segments);
EXPECT_FALSE(rewriter.Rewrite(request, &segments));
// Covers the interplay between CheckResizeSegmentsRequest() and Rewrite() when
// the key is spread over several segments, when the segments were already
// resized, and when a leading HISTORY segment is present.
TEST_F(UnicodeRewriterTest, MultipleSegment) {
  UnicodeRewriter rewriter;

  {
    // Multiple segments to be combined.
    Segments segments;
    InitSegments("U+0", "U+0", &segments);
    AddSegment("02", "02", &segments);
    AddSegment("0", "0", &segments);
    const ConversionRequest request =
        ConversionRequestBuilder().SetKey("U+0020").Build();
    std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
        rewriter.CheckResizeSegmentsRequest(request, segments);
    // ASSERT (not EXPECT): the optional is dereferenced right below.
    ASSERT_TRUE(resize_request.has_value());
    EXPECT_EQ(resize_request->segment_index, 0);
    EXPECT_EQ(resize_request->segment_sizes[0], 6);
    EXPECT_EQ(resize_request->segment_sizes[1], 0);
  }

  {
    // The segments is already resized.
    Segments segments;
    InitSegments("U+0", "U+0", &segments);
    AddSegment("02", "02", &segments);
    AddSegment("0", "0", &segments);
    segments.set_resized(true);
    const ConversionRequest request =
        ConversionRequestBuilder().SetKey("U+0020").Build();
    std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
        rewriter.CheckResizeSegmentsRequest(request, segments);
    EXPECT_FALSE(resize_request.has_value());
  }

  {
    // The size of segments is one.
    Segments segments;
    InitSegments("U+0020", "U+0020", &segments);
    const ConversionRequest request =
        ConversionRequestBuilder().SetKey("U+0020").Build();
    std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
        rewriter.CheckResizeSegmentsRequest(request, segments);
    EXPECT_FALSE(resize_request.has_value());
    EXPECT_TRUE(rewriter.Rewrite(request, &segments));
    EXPECT_EQ(segments.conversion_segment(0).candidate(0).value.at(0), ' ');
  }

  {
    // History segment has to be ignored.
    // The remaining conversion key "020" is not a code point expression, so
    // neither a resize nor a rewrite is expected.
    Segments segments;
    InitSegments("U+0", "U+0", &segments);
    AddSegment("02", "02", &segments);
    AddSegment("0", "0", &segments);
    segments.mutable_segment(0)->set_segment_type(Segment::HISTORY);
    const ConversionRequest request =
        ConversionRequestBuilder().SetKey("020").Build();
    std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
        rewriter.CheckResizeSegmentsRequest(request, segments);
    EXPECT_FALSE(resize_request.has_value());
    EXPECT_FALSE(rewriter.Rewrite(request, &segments));
  }

  {
    // History segment has to be ignored.
    // In this case 1st segment is HISTORY
    // so this rewriting returns true.
    Segments segments;
    InitSegments("U+0020", "U+0020", &segments);
    AddSegment("U+0020", "U+0020", &segments);
    segments.set_resized(true);
    segments.mutable_segment(0)->set_segment_type(Segment::HISTORY);
    const ConversionRequest request =
        ConversionRequestBuilder().SetKey("U+0020").Build();
    std::optional<RewriterInterface::ResizeSegmentsRequest> resize_request =
        rewriter.CheckResizeSegmentsRequest(request, segments);
    EXPECT_FALSE(resize_request.has_value());
    EXPECT_TRUE(rewriter.Rewrite(request, &segments));
    EXPECT_EQ(segments.conversion_segment(0).candidate(0).value.at(0), ' ');
  }
}

TEST_F(UnicodeRewriterTest, RewriteToUnicodeCharFormat) {
UnicodeRewriter rewriter(engine_->GetConverter());
UnicodeRewriter rewriter;
{ // Typical case
composer::Composer composer(nullptr, &default_request(), &default_config());
composer.set_source_text("A");
Expand Down

0 comments on commit 91729ed

Please sign in to comment.