Skip to content

Commit

Permalink
limit rag file amount (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangtianyu authored Oct 30, 2024
1 parent 32ce0ea commit 475412a
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import com.zhongan.devpilot.util.DevPilotMessageBundle;
import com.zhongan.devpilot.util.JsonUtils;
import com.zhongan.devpilot.util.MessageUtil;
import com.zhongan.devpilot.util.PsiElementUtils;
import com.zhongan.devpilot.util.TokenUtils;
import com.zhongan.devpilot.webview.model.CodeReferenceModel;
import com.zhongan.devpilot.webview.model.EmbeddedModel;
Expand Down Expand Up @@ -151,9 +150,9 @@ public void smartChat(Integer sessionType, String msgType, Map<String, String> d

if (localRef != null) {
ApplicationManager.getApplication().runReadAction(() -> {
var relatedCode = PsiElementUtils.transformElementToString(localRef);
newMap.put("relatedContext", relatedCode);
// newMap.put("additionalRelatedContext", null);
var language = messageModel.getCodeRef() == null ? null : messageModel.getCodeRef().getLanguageId();
FileAnalyzeProviderFactory.getProvider(language)
.buildRelatedContextDataMap(project, messageModel.getCodeRef(), localRef, null, data);
localRefs[0] = CodeReferenceModel.getCodeRefListFromPsiElement(localRef, EditorActionEnum.getEnumByName(msgType));
});
}
Expand Down Expand Up @@ -212,9 +211,9 @@ public void regenerateSmartChat(MessageModel messageModel, Consumer<String> call

if (localRef != null) {
ApplicationManager.getApplication().runReadAction(() -> {
var relatedCode = PsiElementUtils.transformElementToString(localRef);
data.put("relatedContext", relatedCode);
// data.put("additionalRelatedContext", null);
var language = messageModel.getCodeRef() == null ? null : messageModel.getCodeRef().getLanguageId();
FileAnalyzeProviderFactory.getProvider(language)
.buildRelatedContextDataMap(project, messageModel.getCodeRef(), localRef, null, data);
localRefs[0] = CodeReferenceModel.getCodeRefListFromPsiElement(localRef, messageModel.getCodeRef().getType());
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ public void buildTestDataMap(Project project, Editor editor, Map<String, String>
// default do nothing
}

@Override
public void buildRelatedContextDataMap(Project project, CodeReferenceModel codeReference, List<PsiElement> localRef, List<PsiElement> remoteRef, Map<String, String> data) {

}

@Override
public List<PsiElement> callLocalRag(Project project, DevPilotCodePrediction codePrediction) {
return List.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ public interface FileAnalyzeProvider {

void buildTestDataMap(Project project, Editor editor, Map<String, String> data);

void buildRelatedContextDataMap(Project project, CodeReferenceModel codeReference, List<PsiElement> localRef, List<PsiElement> remoteRef, Map<String, String> data);

List<PsiElement> callLocalRag(Project project, DevPilotCodePrediction codePrediction);
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ public void buildTestDataMap(Project project, Editor editor, Map<String, String>
}
}

@Override
public void buildRelatedContextDataMap(Project project, CodeReferenceModel codeReference, List<PsiElement> localRef, List<PsiElement> remoteRef, Map<String, String> data) {
String packageName = null;

if (codeReference != null && codeReference.getFileUrl() != null) {
var psiJavaFile = PsiElementUtils.getPsiJavaFileByFilePath(project, codeReference.getFileUrl());
if (psiJavaFile != null) {
packageName = psiJavaFile.getPackageName();
}
}

var relatedCode = PsiElementUtils.transformElementToString(localRef, packageName);
data.put("relatedContext", relatedCode);
// data.put("additionalRelatedContext", null);
}

@Override
public List<PsiElement> callLocalRag(Project project, DevPilotCodePrediction codePrediction) {
return PsiElementUtils.contextRecall(project, codePrediction);
Expand Down
70 changes: 63 additions & 7 deletions src/main/java/com/zhongan/devpilot/util/PsiElementUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.jar.Attributes;
import java.util.jar.JarFile;
import java.util.jar.Manifest;
import java.util.stream.Collectors;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
Expand All @@ -47,6 +48,8 @@
public class PsiElementUtils {
private static final int MAX_LINE_COUNT = 1000;

private static final int MAX_CODE_SNIPPET_COUNT = 10;

public static String getFullClassName(@NotNull PsiElement element) {
if (element instanceof PsiMethod) {
var psiClass = ((PsiMethod) element).getContainingClass();
Expand All @@ -60,13 +63,12 @@ public static String getFullClassName(@NotNull PsiElement element) {
return null;
}

public static <T extends PsiElement> String transformElementToString(Collection<T> elements) {
public static String transformElementToString(Collection<PsiElement> elements, String packageName) {
var result = new StringBuilder();

for (T element : elements) {
if (shouldIgnorePsiElement(element)) {
continue;
}
var filteredElements = filterElements(elements, packageName);

for (var element : filteredElements) {
if (element instanceof PsiClass) {
PsiClass psiClass = (PsiClass) element;
result.append("Class: ").append(psiClass.getQualifiedName()).append("\n\n");
Expand Down Expand Up @@ -355,7 +357,61 @@ public static List<PsiElement> contextRecall(Project project, DevPilotCodePredic
return doRecall(project, finalRefs);
}

public static String filterLargeElement(PsiElement element) {
// filter elements when elements amount larger than MAX_CODE_SNIPPET_COUNT
private static Collection<PsiElement> filterElements(Collection<PsiElement> elements, String packageName) {
if (CollectionUtils.isEmpty(elements)) {
return elements;
}

elements = elements.stream()
.filter(element -> !shouldIgnorePsiElement(element)).collect(Collectors.toList());

if (elements.size() <= MAX_CODE_SNIPPET_COUNT) {
return elements;
}

if (StringUtils.isEmpty(packageName)) {
return elements.stream().limit(MAX_CODE_SNIPPET_COUNT).collect(Collectors.toList());
}

var resultList = new ArrayList<>(elements);

// Sort elements by the length of the common prefix with the package name
// The longer the common prefix, the higher the priority
resultList.sort((o1, o2) -> {
var m1 = maxCommonPrefixLength(o1, packageName);
var m2 = maxCommonPrefixLength(o2, packageName);
return m2 - m1;
});

return resultList.stream().limit(MAX_CODE_SNIPPET_COUNT).collect(Collectors.toList());
}

private static int maxCommonPrefixLength(PsiElement element, String packageName) {
String elementPackage = null;
if (element instanceof PsiClass) {
elementPackage = ((PsiClass) element).getQualifiedName();
} else if (element instanceof PsiMethod) {
var method = ((PsiMethod) element);
if (method.getContainingClass() != null) {
elementPackage = method.getContainingClass().getQualifiedName();
}
}

if (StringUtils.isEmpty(elementPackage) || StringUtils.isEmpty(packageName)) {
return 0;
}

var minLength = Math.min(elementPackage.length(), packageName.length());
for (int i = 0; i < minLength; i++) {
if (elementPackage.charAt(i) != packageName.charAt(i)) {
return i;
}
}
return minLength;
}

private static String filterLargeElement(PsiElement element) {
var lineCount = getLineCount(element);
if (lineCount > MAX_LINE_COUNT) {
return simplifyElement(element);
Expand All @@ -364,7 +420,7 @@ public static String filterLargeElement(PsiElement element) {
}
}

public static String simplifyElement(PsiElement element) {
private static String simplifyElement(PsiElement element) {
if (element instanceof PsiClass) {
var psiClass = (PsiClass) element;
return simplifyClass(psiClass);
Expand Down

0 comments on commit 475412a

Please sign in to comment.