diff --git a/backend/epimetheus/.gitignore b/backend/epimetheus/.gitignore index c0ddbeff..e996f8d0 100644 --- a/backend/epimetheus/.gitignore +++ b/backend/epimetheus/.gitignore @@ -2,6 +2,7 @@ HELP.md .gradle build/ !gradle/wrapper/gradle-wrapper.jar +!gradle/wrapper/gradle-wrapper.properties !**/src/main/**/build/ !**/src/test/**/build/ @@ -37,7 +38,6 @@ out/ .vscode/ ### properties ### -*.properties *.yml *.yaml -src/main/resources/*.yml +src/main/resources/*.yml \ No newline at end of file diff --git a/backend/epimetheus/build.gradle b/backend/epimetheus/build.gradle index 36d6c709..99d85a54 100644 --- a/backend/epimetheus/build.gradle +++ b/backend/epimetheus/build.gradle @@ -23,25 +23,18 @@ repositories { dependencies { implementation 'org.springframework.boot:spring-boot-starter-data-mongodb' -// implementation 'org.springframework.boot:spring-boot-starter-data-redis' -// implementation 'org.springframework.boot:spring-boot-starter-jdbc' implementation 'org.springframework.boot:spring-boot-starter-websocket' implementation 'org.springframework.boot:spring-boot-starter-webflux' implementation 'org.springframework.boot:spring-boot-starter-web' -// implementation 'org.mybatis.spring.boot:mybatis-spring-boot-starter:3.0.2' -// implementation 'org.springframework.session:spring-session-data-redis' -// implementation 'org.springframework.session:spring-session-jdbc' compileOnly 'org.projectlombok:lombok' developmentOnly 'org.springframework.boot:spring-boot-devtools' -// developmentOnly 'org.springframework.boot:spring-boot-docker-compose' runtimeOnly 'io.micrometer:micrometer-registry-prometheus' -// runtimeOnly 'org.mariadb.jdbc:mariadb-java-client' annotationProcessor 'org.projectlombok:lombok' testImplementation 'org.springframework.boot:spring-boot-starter-test' -// testImplementation 'org.mybatis.spring.boot:mybatis-spring-boot-starter-test:3.0.2' implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1' - // https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-databind implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.15.2' + implementation group: 'info.debatty', name: 'java-string-similarity', version: '2.0.0' + } tasks.named('test') { diff --git a/backend/epimetheus/gradle/wrapper/gradle-wrapper.properties b/backend/epimetheus/gradle/wrapper/gradle-wrapper.properties index 9f4197d5..62f495df 100644 --- a/backend/epimetheus/gradle/wrapper/gradle-wrapper.properties +++ b/backend/epimetheus/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.2.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaAdapter.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaAdapter.java index 6ac35489..b21faa15 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaAdapter.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaAdapter.java @@ -1,8 +1,11 @@ package uos.capstone.epimetheus.adapter; import reactor.core.publisher.Flux; -import uos.capstone.epimetheus.dtos.LlamaResponse; +import reactor.core.publisher.Mono; +import uos.capstone.epimetheus.dtos.LlamaStepResponse; public interface LlamaAdapter { - Flux fetchDataAsStream(String json); + Flux getAllTaskSteps(String json); + + Mono getVectorFromSentence(String sentence); } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaServerStreamAdapter.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaServerStreamAdapter.java index fcd8f443..0d3f2507 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaServerStreamAdapter.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/adapter/LlamaServerStreamAdapter.java @@ -14,9 +14,8 @@ import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; import reactor.core.publisher.Flux; -import uos.capstone.epimetheus.dtos.LlamaRequest; -import uos.capstone.epimetheus.dtos.LlamaResponse; -import uos.capstone.epimetheus.dtos.LlamaRequestMessage; +import reactor.core.publisher.Mono; +import uos.capstone.epimetheus.dtos.*; import java.io.IOException; @@ -32,22 +31,24 @@ public class LlamaServerStreamAdapter implements LlamaAdapter{ private final ObjectMapper objectMapper; - @Value("${llama.url}") - String url; + @Value("${llama.generate.step}") + String stepGenerateUrl; + @Value("${llama.generate.vector}") + String vectorGenerateUrl; @Value("${prompt}") Resource prompt; - private String requestBodyBuilder(String task){ + private String stepGenerateRequestBodyBuilder(String task){ Gson gson = new Gson(); - LlamaRequest request = LlamaRequest.builder() + LlamaStepRequest request = LlamaStepRequest.builder() .max_tokens(1024) .temperature(0) .messages(List.of( - LlamaRequestMessage.builder() - .content(readTextFile()) + LlamaPromptRequestMessage.builder() + .content(readPrompt()) .role("system") .build(), - LlamaRequestMessage.builder() + LlamaPromptRequestMessage.builder() .content(task) .role("user") .build() @@ -58,7 +59,7 @@ private String requestBodyBuilder(String task){ return gson.toJson(request); } - private String readTextFile(){ + private String readPrompt(){ try { InputStream inputStream = prompt.getInputStream(); return new String(FileCopyUtils.copyToByteArray(inputStream)); @@ -70,23 +71,25 @@ private String readTextFile(){ } @Override - public Flux fetchDataAsStream(String json){ + public Flux getAllTaskSteps(String task){ - String body = requestBodyBuilder(json); + String body = stepGenerateRequestBodyBuilder(task); try{ + log.info("Step Generate Request to Llama with task: " + task); return webClient.post() - .uri(url) + .uri(stepGenerateUrl) .contentType(MediaType.APPLICATION_JSON) .body(BodyInserters.fromValue(body)) .retrieve() .bodyToFlux(String.class) + .doOnError(e -> log.info("Error occurred while making web request", e)) .flatMap(responseString -> { if ("[DONE]".equals(responseString.trim())) { - return Flux.just(new LlamaResponse()); + return Flux.just(LlamaStepResponse.eof()); } else { try { - LlamaResponse llamaResponse = objectMapper.readValue(responseString, LlamaResponse.class); - return Flux.just(llamaResponse); + LlamaStepResponse llamaStepResponse = objectMapper.readValue(responseString, LlamaStepResponse.class); + return Flux.just(llamaStepResponse); } catch (JsonProcessingException e) { return Flux.error(e); } @@ -101,4 +104,30 @@ public Flux fetchDataAsStream(String json){ } } + + @Override + public Mono getVectorFromSentence(String sentence) { + String body = vectorGenerateRequestBodyBuilder(sentence); + + return webClient.post() + .uri(vectorGenerateUrl) + .contentType(MediaType.APPLICATION_JSON) + .body(BodyInserters.fromValue(body)) + .retrieve() + .bodyToMono(LlamaVectorResponse.class) + .map(LlamaVectorResponse::getVector); + } + + + private String vectorGenerateRequestBodyBuilder(String sentence) { + Gson gson = new Gson(); + + LlamaVectorRequest llamaVectorRequest = LlamaVectorRequest.builder() + .input(sentence) + .build(); + + return gson.toJson(llamaVectorRequest); + } + + } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/controller/TaskController.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/controller/TaskController.java index 076e0a17..27e0c75a 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/controller/TaskController.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/controller/TaskController.java @@ -10,6 +10,7 @@ import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux; import uos.capstone.epimetheus.dtos.TaskStep; +import uos.capstone.epimetheus.dtos.llamaTasks.SubTaskCode; import uos.capstone.epimetheus.dtos.llamaTasks.SubTaskResolver; import uos.capstone.epimetheus.service.TaskSerivce; @@ -23,11 +24,13 @@ public class TaskController { @GetMapping(path = "/tasks", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux getTask(@RequestParam String task) { + log.info("[/tasks] Task : " + task); return taskSerivce.getSubTaskListInStream(task); } @PostMapping(path = "/save") - public ResponseEntity saveCode(@RequestBody TaskStep taskStep){ + public ResponseEntity saveCode(@RequestBody TaskStep taskStep) { + log.info("[/save] Save Code : " + taskStep); String response = taskSerivce.saveCode(taskStep); HttpStatusCode status; if(response.equals("not code")){ @@ -39,4 +42,10 @@ public ResponseEntity saveCode(@RequestBody TaskStep taskStep){ } return ResponseEntity.status(status).body(response); } + + @GetMapping("/code") + public SubTaskCode getSimilar(@RequestParam String input) { + log.info("[/code] Similar Task Input : " + input); + return taskSerivce.getSimilarCode(input); + } } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaRequestMessage.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaPromptRequestMessage.java similarity index 68% rename from backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaRequestMessage.java rename to backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaPromptRequestMessage.java index a3d79014..a2c43403 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaRequestMessage.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaPromptRequestMessage.java @@ -4,13 +4,13 @@ import lombok.Getter; @Getter -public class LlamaRequestMessage { +public class LlamaPromptRequestMessage { private final String content; private final String role; @Builder - public LlamaRequestMessage(String content, String role) { + public LlamaPromptRequestMessage(String content, String role) { this.content = content; this.role = role; } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaResponse.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaResponse.java deleted file mode 100644 index 7bd2c56d..00000000 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaResponse.java +++ /dev/null @@ -1,29 +0,0 @@ -package uos.capstone.epimetheus.dtos; - -import lombok.Getter; - -import java.util.List; - -@Getter -public class LlamaResponse { - private List choices; - - public String parseContent() { - try { - return choices.get(0).getDelta().getContent(); - } catch (NullPointerException e) { - return ""; - } - } - - -} -@Getter -class Choice { - private Delta delta; - -} -@Getter -class Delta { - private String content = ""; -} diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaRequest.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaStepRequest.java similarity index 64% rename from backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaRequest.java rename to backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaStepRequest.java index b8435ebb..328b1423 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaRequest.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaStepRequest.java @@ -6,15 +6,15 @@ import java.util.List; @Getter -public class LlamaRequest { +public class LlamaStepRequest { - private final List messages; + private final List messages; private final int temperature; private final int max_tokens; private final boolean stream; @Builder - public LlamaRequest(List messages, int temperature, int max_tokens, boolean stream) { + public LlamaStepRequest(List messages, int temperature, int max_tokens, boolean stream) { this.messages = messages; this.temperature = temperature; this.max_tokens = max_tokens; diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaStepResponse.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaStepResponse.java new file mode 100644 index 00000000..2c7abb2d --- /dev/null +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaStepResponse.java @@ -0,0 +1,47 @@ +package uos.capstone.epimetheus.dtos; + +import lombok.Getter; +import lombok.NoArgsConstructor; + +import java.util.ArrayList; +import java.util.List; + +@Getter +public class LlamaStepResponse { + private List choices; + + public String parseContent() { + try { + return choices.get(0).getDelta().getContent(); + } catch (NullPointerException e) { + return ""; + } + } + + public static LlamaStepResponse eof() { + LlamaStepResponse response = new LlamaStepResponse(); + response.choices = new ArrayList<>(); + response.choices.add(new Choice("[DONE]")); + + return response; + } +} +@Getter +@NoArgsConstructor +class Choice { + private Delta delta; + + Choice(String content) { + this.delta = new Delta(content); + } + +} +@Getter +@NoArgsConstructor +class Delta { + private String content = ""; + + Delta(String content) { + this.content = content; + } +} diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaVectorRequest.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaVectorRequest.java new file mode 100644 index 00000000..83b28f9d --- /dev/null +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaVectorRequest.java @@ -0,0 +1,15 @@ +package uos.capstone.epimetheus.dtos; + +import lombok.Builder; +import lombok.Getter; + +@Getter +public class LlamaVectorRequest { + + private String input; + + @Builder + public LlamaVectorRequest(String input) { + this.input = input; + } +} diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaVectorResponse.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaVectorResponse.java new file mode 100644 index 00000000..8199a9f4 --- /dev/null +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/LlamaVectorResponse.java @@ -0,0 +1,35 @@ +package uos.capstone.epimetheus.dtos; + +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.Getter; + +import java.util.List; + +@Getter +public class LlamaVectorResponse { + + private List data; + private TokenUsage usage; + + public double[] getVector() { + if((data != null ? data.size() : 0) != 1) { + throw new RuntimeException("Invalid Data Came"); + } + + return data.get(0).getEmbedding(); + } +} + +@Getter +class EmbeddingData { + + private String object; + private double[] embedding; + private int index; +} + +@Getter +class TokenUsage { + private int prompt_tokens; + private int total_tokens; +} diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/TaskStep.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/TaskStep.java index e2ed6eee..67a31a27 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/TaskStep.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/TaskStep.java @@ -2,28 +2,35 @@ import lombok.Builder; import org.springframework.data.annotation.Id; +import org.springframework.data.mongodb.core.mapping.DBRef; import org.springframework.data.mongodb.core.mapping.Document; import uos.capstone.epimetheus.dtos.llamaTasks.CodeLanguage; +import java.util.Arrays; +import java.util.Objects; + @Document(collection = "subtask") public class TaskStep { @Id String title; + double[] values; CodeLanguage language; String code; @Builder - public TaskStep(String title, CodeLanguage language, String code){ + public TaskStep(String title, double[] values, CodeLanguage language, String code){ this.title = title; + this.values = values; this.language = language; this.code = code; } - public static TaskStep of(String title) { + public static TaskStep of(String title, double[] vector) { return TaskStep.builder() .title(title) + .values(vector) .language(CodeLanguage.DEFAULT) .code("") .build(); @@ -44,4 +51,33 @@ public String getLanguage() { return language.getLanguage(); } + + public double[] getValues() { + return this.values; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TaskStep taskStep = (TaskStep) o; + return Objects.equals(title, taskStep.title) && Arrays.equals(values, taskStep.values) && language == taskStep.language && Objects.equals(code, taskStep.code); + } + + @Override + public int hashCode() { + int result = Objects.hash(title, language, code); + result = 31 * result + Arrays.hashCode(values); + return result; + } + + @Override + public String toString() { + return "TaskStep{" + + "title='" + title + '\'' + + ", values=" + Arrays.toString(values) + + ", language=" + language + + ", code='" + code + '\'' + + '}'; + } } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/llamaTasks/SubTaskCode.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/llamaTasks/SubTaskCode.java index 49017e00..651b3b40 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/llamaTasks/SubTaskCode.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/dtos/llamaTasks/SubTaskCode.java @@ -31,4 +31,8 @@ public String getProperty() { public String getLanguage() { return language.getLanguage(); } + + public String getCode() { + return code; + } } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/CosineSimilarityService.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/CosineSimilarityService.java new file mode 100644 index 00000000..90747f6b --- /dev/null +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/CosineSimilarityService.java @@ -0,0 +1,55 @@ +package uos.capstone.epimetheus.service; + +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; +import uos.capstone.epimetheus.adapter.LlamaAdapter; +import uos.capstone.epimetheus.dtos.TaskStep; + +import java.util.Comparator; +import java.util.Optional; + +@Service +@RequiredArgsConstructor +public class CosineSimilarityService implements SimilarityService { + + private final LlamaAdapter llamaAdapter; + private final DatabaseService databaseService; + + @Override + public TaskStep getSimilarStep(String step) { + + double[] inputVector = llamaAdapter.getVectorFromSentence(step).block(); + Optional similar = databaseService.getAllData().stream() + .filter(data -> data.getValues() != null && data.getValues().length == inputVector.length) + .filter(data -> cosineSimilarity(inputVector, data.getValues()) >= 0.8) + .sorted(Comparator.comparing(data -> (-1) * cosineSimilarity(inputVector, data.getValues()))) + .findFirst(); + + return similar.orElseGet(() -> databaseService.saveByTitle(step, inputVector)); + } + + + private double cosineSimilarity(double[] input, double[] toCompare) { + if (input.length != toCompare.length) { + throw new IllegalArgumentException("Vectors must have the same length"); + } + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < input.length; i++) { + dotProduct += input[i] * toCompare[i]; + normA += input[i] * input[i]; + normB += toCompare[i] * toCompare[i]; + } + + if (normA == 0 || normB == 0) { + // If one of the vectors has a magnitude of 0, then similarity is undefined (returning 0 or a special case value might be appropriate) + return 0; + } + + double result = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + return result; + } +} diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/DatabaseService.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/DatabaseService.java index a9eba070..bc66d122 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/DatabaseService.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/DatabaseService.java @@ -2,8 +2,13 @@ import uos.capstone.epimetheus.dtos.TaskStep; +import java.util.List; + public interface DatabaseService { - TaskStep getTaskStepByTitle(String id); + + TaskStep saveByTitle(String step, double[] vector); void saveCode(TaskStep taskStep); + + List getAllData(); } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/MongoDBServiceImpl.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/MongoDBServiceImpl.java index 69b4c769..d54c2d74 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/MongoDBServiceImpl.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/MongoDBServiceImpl.java @@ -6,6 +6,8 @@ import uos.capstone.epimetheus.dtos.TaskStep; import uos.capstone.epimetheus.repository.MongoDBRepository; +import java.util.List; + @Service @RequiredArgsConstructor @Log4j2 @@ -13,8 +15,8 @@ public class MongoDBServiceImpl implements DatabaseService { private final MongoDBRepository mongoRepository; @Override - public TaskStep getTaskStepByTitle(String id){ - return mongoRepository.findById(id).orElse(mongoRepository.save(TaskStep.of(id))); + public TaskStep saveByTitle(String step, double[] vector) { + return mongoRepository.save(TaskStep.of(step, vector)); } @Override @@ -22,4 +24,9 @@ public void saveCode(TaskStep taskStep){ mongoRepository.save(taskStep); } + @Override + public List getAllData() { + return mongoRepository.findAll(); + } + } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/SimilarityService.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/SimilarityService.java new file mode 100644 index 00000000..092b3cee --- /dev/null +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/SimilarityService.java @@ -0,0 +1,7 @@ +package uos.capstone.epimetheus.service; + +import uos.capstone.epimetheus.dtos.TaskStep; + +public interface SimilarityService { + TaskStep getSimilarStep(String step); +} diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskSerivce.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskSerivce.java index 95fb9123..470b3dc9 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskSerivce.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskSerivce.java @@ -2,6 +2,7 @@ import reactor.core.publisher.Flux; import uos.capstone.epimetheus.dtos.TaskStep; +import uos.capstone.epimetheus.dtos.llamaTasks.SubTaskCode; import uos.capstone.epimetheus.dtos.llamaTasks.SubTaskResolver; public interface TaskSerivce { @@ -9,4 +10,6 @@ public interface TaskSerivce { Flux getSubTaskListInStream(String task); String saveCode(TaskStep taskStep); + + SubTaskCode getSimilarCode(String step); } diff --git a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskServiceStreamImpl.java b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskServiceStreamImpl.java index 30dc5f6e..a1d1d65d 100644 --- a/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskServiceStreamImpl.java +++ b/backend/epimetheus/src/main/java/uos/capstone/epimetheus/service/TaskServiceStreamImpl.java @@ -1,25 +1,27 @@ package uos.capstone.epimetheus.service; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import uos.capstone.epimetheus.adapter.LlamaAdapter; -import uos.capstone.epimetheus.dtos.LlamaResponse; import uos.capstone.epimetheus.dtos.TaskStep; import uos.capstone.epimetheus.dtos.llamaTasks.*; - import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Matcher; import java.util.regex.Pattern; @RequiredArgsConstructor @Service -public class TaskServiceStreamImpl implements TaskSerivce{ +@Log4j2 +public class TaskServiceStreamImpl implements TaskSerivce { private final LlamaAdapter llamaAdapter; - private final DatabaseService databaseService; + private final SimilarityService similarityService; + + private final String stopWord = "ginger"; @Override public Flux getSubTaskListInStream(String task) { @@ -31,108 +33,75 @@ public Flux getSubTaskListInStream(String task) { Pattern pattern = Pattern.compile("!!(\\d+)\\."); - return Flux.create(sink -> { - llamaAdapter.fetchDataAsStream(task) - .map(LlamaResponse::parseContent) - .doOnNext(data -> { - - if (buffer.indexOf("Intro:") != -1) { - state.set(0); - buffer.setLength(0); - } - else if (buffer.indexOf("ginger") != -1) { - if(state.get() == 0) { - intro.append(gingerParse(buffer)); + return llamaAdapter.getAllTaskSteps(task).flatMap(llamaStepResponse -> { + Flux subTask = Flux.empty(); + String data = llamaStepResponse.parseContent(); + Matcher matcher = pattern.matcher(buffer); + boolean patternFound = matcher.find(); + if (buffer.indexOf("Intro:") != -1 || patternFound && state.get() == 0) { + state.set(0); + buffer.setLength(0); + } else if (buffer.indexOf("Outro:") != -1) { + state.set(3); + buffer.setLength(0); + stepNo.set(0); + } else if (data.equals("[DONE]")) { + subTask = Flux.just(SubTaskWrap.builder() + .stepNo(0) + .wrapper(endOfFluxParse(buffer)) + .property(ResponseStreamProperty.OUTRO) + .build()); + } else if (buffer.indexOf(stopWord) == -1 && !patternFound) { + subTask = Flux.empty(); + } else { + boolean type = buffer.indexOf(stopWord) != -1; + String content = type ? stopWordParse(buffer) : matcherParse(buffer, matcher.start()); + switch (state.get()) { + case 0: + intro.append(content); buffer.setLength(0); - sink.next(SubTaskWrap.builder() + subTask = Flux.just(SubTaskWrap.builder() .stepNo(0) - .wrapper(intro.toString()) + .wrapper(content) .property(ResponseStreamProperty.INTRO) .build()); - } - else if (state.get() == 1) { - String title = gingerParse(buffer); - sink.next(SubTaskTitle.builder() + break; + case 1: + if(stepNo.get() == 0) { + subTask = Flux.empty(); + break; + } + subTask = Flux.just(SubTaskTitle.builder() .stepNo(stepNo.get()) - .title(title) + .title(content) .property(ResponseStreamProperty.TITLE) .build()); - TaskStep taskStep = databaseService.getTaskStepByTitle(title); - sink.next(SubTaskCode.builder() + break; + case 2: + if(stepNo.get() == 0) + break; + subTask = Flux.just(SubTaskDescription.builder() .stepNo(stepNo.get()) - .code(taskStep.getCode()) - .property(ResponseStreamProperty.CODE) - .language(CodeLanguage.of(taskStep.getLanguage())) - .build()); - } else if (state.get() == 2) { - String description = gingerParse(buffer); - sink.next(SubTaskDescription.builder() - .stepNo(stepNo.get()) - .description(description) + .description(content) .property(ResponseStreamProperty.DESCRIPTION) .build()); - } - stepNo.set(0); - buffer.setLength(0); - state.incrementAndGet(); - } - else if (buffer.indexOf("Outro:") != -1) { - state.set(3); - buffer.setLength(0); - stepNo.set(0); + break; + default: + subTask = Flux.error(new RuntimeException("Invalid Prompt")); } - else { - Matcher matcher = pattern.matcher(buffer); - if(matcher.find()) { - if (stepNo.get() == 0) { - buffer.setLength(0); - } else if (state.get() == 1) { - String title = matcerParse(buffer, matcher.start()); - sink.next(SubTaskTitle.builder() - .stepNo(stepNo.get()) - .title(title) - .property(ResponseStreamProperty.TITLE) - .build()); - TaskStep taskStep = databaseService.getTaskStepByTitle(title); - sink.next(SubTaskCode.builder() - .stepNo(stepNo.get()) - .code(taskStep.getCode()) - .property(ResponseStreamProperty.CODE) - .language(CodeLanguage.of(taskStep.getLanguage())) - .build()); - } else if (state.get() == 2) { - String description = matcerParse(buffer, matcher.start()); - sink.next(SubTaskDescription.builder() - .stepNo(stepNo.get()) - .description(description) - .property(ResponseStreamProperty.DESCRIPTION) - .build()); - } - buffer.setLength(0); - stepNo.incrementAndGet(); - } + if (type) { + stepNo.set(0); + state.incrementAndGet(); } - buffer.append(data); - }) - .doOnComplete(() -> { - if(state.get() == 3) { - sink.next(SubTaskWrap.builder() - .stepNo(stepNo.get()) - .wrapper(buffer.toString().trim()) - .property(ResponseStreamProperty.OUTRO) - .build()); - }else{ - sink.next(SubTaskWrap.builder() - .stepNo(0) - .wrapper(buffer.toString().trim()) - .property(ResponseStreamProperty.ERROR) - .build()); + else { + stepNo.incrementAndGet(); } - sink.complete(); - }).subscribe(); - }); - + buffer.setLength(0); + } + buffer.append(data); + return subTask; + }); } @Override @@ -150,12 +119,27 @@ public String saveCode(TaskStep taskStep){ } } - private String gingerParse(StringBuffer stringBuffer){ + private String stopWordParse(StringBuffer stringBuffer){ return stringBuffer.substring(0, stringBuffer.indexOf("ginger")).trim(); } - private String matcerParse(StringBuffer stringBuffer, int m){ + private String endOfFluxParse(StringBuffer stringBuffer){ + return stringBuffer.toString().trim(); + } + + private String matcherParse(StringBuffer stringBuffer, int m){ return stringBuffer.substring(0, m).trim(); } + + @Override + public SubTaskCode getSimilarCode(String step) { + TaskStep stepCode = similarityService.getSimilarStep(step); + + return SubTaskCode.builder() + .code(stepCode.getCode()) + .property(ResponseStreamProperty.CODE) + .language(CodeLanguage.of(stepCode.getLanguage())) + .build(); + } }