diff --git a/.gitignore b/.gitignore index 8291da049..bd582b444 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ lib/JUnitRunner.jar *.DS_Store .intellijPlatform/ +.kotlin/ + # Test projects src/test/resources/project/.idea/ diff --git a/build.gradle.kts b/build.gradle.kts index 61098b4e9..5a06f7bf7 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -98,6 +98,8 @@ if (spaceCredentialsProvided()) { // add build of new source set as the part of UI testing tasks.prepareTestSandbox.configure { + duplicatesStrategy = DuplicatesStrategy.EXCLUDE + dependsOn(hasGrazieAccess.jarTaskName) from( tasks @@ -114,6 +116,8 @@ if (spaceCredentialsProvided()) { } // add build of new source set as the part of pluginBuild process tasks.prepareSandbox.configure { + duplicatesStrategy = DuplicatesStrategy.EXCLUDE + dependsOn(hasGrazieAccess.jarTaskName) from( tasks @@ -212,12 +216,13 @@ dependencies { // https://mvnrepository.com/artifact/net.jqwik/jqwik testImplementation("net.jqwik:jqwik:1.6.5") + // https://mvnrepository.com/artifact/org.jetbrains.kotlinx/kotlinx-coroutines-test + testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.10.1") + // https://mvnrepository.com/artifact/com.github.javaparser/javaparser-symbol-solver-core implementation("com.github.javaparser:javaparser-symbol-solver-core:3.24.2") // https://mvnrepository.com/artifact/org.jetbrains.kotlin/kotlin-test implementation("org.jetbrains.kotlin:kotlin-test:1.8.0") - - implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.3") } // Configure Gradle IntelliJ Plugin - read more: // Configure Gradle IntelliJ Plugin - read more: https://github.com/JetBrains/gradle-intellij-plugin diff --git a/core/build.gradle.kts b/core/build.gradle.kts index f9ae862d2..5e67a3c34 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -1,6 +1,7 @@ plugins { - kotlin("jvm") + kotlin("jvm") version "2.1.0" `maven-publish` + kotlin("plugin.serialization") version "2.1.0" } group = "org.jetbrains.research" @@ -18,6 +19,13 @@ dependencies { compileOnly(kotlin("stdlib")) implementation("io.github.oshai:kotlin-logging-jvm:6.0.3") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.1") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.3") + + val ktorVersion = "2.3.13" + implementation("io.ktor:ktor-client-core:$ktorVersion") + implementation("io.ktor:ktor-client-cio:$ktorVersion") + implementation("io.ktor:ktor-client-logging:$ktorVersion") } tasks.test { diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt index b5decde39..299fddc7d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/ChatMessage.kt @@ -1,19 +1,20 @@ package org.jetbrains.research.testspark.core.data -open class ChatMessage protected constructor( +data class ChatMessage( val role: ChatRole, - val content: String, + val contentBuilder: StringBuilder, ) { enum class ChatRole { User, Assistant, } -} -class ChatUserMessage( - content: String, -) : ChatMessage(ChatRole.User, content) + val content: String + get() = contentBuilder.toString() + + companion object { + fun createUserMessage(message: String) = ChatMessage(ChatRole.User, StringBuilder(message)) -class ChatAssistantMessage( - content: String, -) : ChatMessage(ChatRole.Assistant, content) + fun createAssistantMessage(message: String) = ChatMessage(ChatRole.Assistant, StringBuilder(message)) + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/LlmError.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/LlmError.kt index 87c06b0af..bb4b4e119 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/LlmError.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/LlmError.kt @@ -13,12 +13,12 @@ sealed class LlmError( data object EmptyLlmResponse : LlmError() - data object TestSuiteParsingError : LlmError() + data class TestSuiteParsingError( + override val cause: Throwable?, + ) : LlmError(cause) data object NoCompilableTestCasesGenerated : LlmError() - data object FailedToSaveTestFiles : LlmError() - data object CompilationError : LlmError() data object UnsetTokenError : LlmError() diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/Result.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/Result.kt index c7593b313..a5b91864d 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/Result.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/error/Result.kt @@ -1,13 +1,13 @@ package org.jetbrains.research.testspark.core.error -sealed interface Result { +sealed interface Result { data class Success( val data: D, - ) : Result + ) : Result - data class Failure( - val error: E, - ) : Result + data class Failure( + val error: TestSparkError, + ) : Result fun getDataOrNull(): D? = if (this is Success) data else null @@ -15,13 +15,13 @@ sealed interface Result { fun isFailure(): Boolean = this is Failure - fun mapData(transform: (D) -> R): Result = + fun mapData(transform: (D) -> R): Result = when (this) { is Success -> Success(transform(data)) is Failure -> Failure(error) } - fun mapError(transform: (E) -> R): Result = + fun mapError(transform: (TestSparkError) -> R): Result = when (this) { is Success -> Success(data) is Failure -> Failure(transform(error)) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CommonException.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CommonException.kt index c851e2182..5953d1c92 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CommonException.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CommonException.kt @@ -11,5 +11,6 @@ sealed class CommonException( ) class ProcessCancelledException( - module: TestSparkModule, -) : CommonException(module) + module: TestSparkModule = TestSparkModule.Common, + cause: Throwable? = null, +) : CommonException(module, cause) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CompilerException.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CompilerException.kt index b6ead74cd..28b5b7b61 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CompilerException.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/exception/CompilerException.kt @@ -43,3 +43,5 @@ class ClassFileNotFoundException( val classFilePath: String, val filePath: String, ) : CompilerException() + +class TestSavingFailureException : CompilerException() diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/ChatSessionManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/ChatSessionManager.kt new file mode 100644 index 000000000..837d230d4 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/ChatSessionManager.kt @@ -0,0 +1,79 @@ +package org.jetbrains.research.testspark.core.generation.llm + +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import org.jetbrains.research.testspark.core.data.ChatMessage +import org.jetbrains.research.testspark.core.error.LlmError +import org.jetbrains.research.testspark.core.error.Result +import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmParams + +class ChatSessionManager( + private val requestManager: RequestManager, + private val llmParams: LlmParams, +) { + private val mutex = Mutex() + private val chatHistory = mutableListOf() + private val log = KotlinLogging.logger {} + + suspend fun request( + prompt: String, + isUserFeedback: Boolean, + ): Flow> { + log.info { "Sending Request..." } + + recordChatMessage(isUserFeedback, ChatMessage.Companion.createUserMessage(message = prompt)) + + val chatHistory = + if (isUserFeedback) { + chatHistory + ChatMessage.createUserMessage(prompt) + } else { + chatHistory + } + + return requestManager + .sendRequest( + llmParams, + chatHistory, + ).map { result -> + val responseString = result.getDataOrNull() + if (responseString != null && responseString.isEmpty()) { + Result.Failure(error = LlmError.EmptyLlmResponse) + } else { + result + } + }.onEach { result -> + val rawText = result.getDataOrNull() + if (rawText.isNullOrEmpty().not()) { + recordChatMessage( + isUserFeedback, + ChatMessage.Companion.createAssistantMessage(message = rawText), + ) + } + } + } + + private suspend fun recordChatMessage( + isUserFeedback: Boolean, + message: ChatMessage, + ) { + if (isUserFeedback) return + mutex.withLock { + when (message.role) { + ChatMessage.ChatRole.User -> chatHistory.add(message) + ChatMessage.ChatRole.Assistant -> { + val lastMessage = chatHistory.lastOrNull() + if (lastMessage != null && lastMessage.role == ChatMessage.ChatRole.Assistant) { + lastMessage.contentBuilder.append(message.content) + } else { + chatHistory.add(message) + } + } + } + } + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt index 3e6549ee6..4ee5447b5 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/LLMWithFeedbackCycle.kt @@ -1,18 +1,16 @@ package org.jetbrains.research.testspark.core.generation.llm import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow import org.jetbrains.research.testspark.core.data.Report import org.jetbrains.research.testspark.core.data.TestCase -import org.jetbrains.research.testspark.core.data.TestSparkModule import org.jetbrains.research.testspark.core.error.LlmError import org.jetbrains.research.testspark.core.error.Result import org.jetbrains.research.testspark.core.error.TestSparkError -import org.jetbrains.research.testspark.core.exception.ProcessCancelledException -import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager +import org.jetbrains.research.testspark.core.exception.TestSavingFailureException import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeReductionStrategy -import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator +import org.jetbrains.research.testspark.core.test.ExecutionResult import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestCompiler import org.jetbrains.research.testspark.core.test.TestsAssembler @@ -22,17 +20,6 @@ import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM import java.io.File -/** - * Represents a response (result) of a feedback cycle. - * - * @param generatedTestSuite The test suite generated by LLM. - * @param compilableTestCases The set of compilable test cases generated by LLM. - */ -data class FeedbackResponse( - val generatedTestSuite: TestSuiteGeneratedByLLM, - val compilableTestCases: MutableSet, -) - /** * LLMWithFeedbackCycle class represents a feedback cycle for an LLM. * @@ -41,17 +28,14 @@ data class FeedbackResponse( * @property initialPromptMessage The initial prompt message to start the feedback cycle. * @property promptSizeReductionStrategy The `PromptSizeReductionStrategy` instance used for reducing the prompt size. * @property testSuiteFilename The name of the file in which the test suite is saved in the result path. - * @property packageName The package name for the generated tests. * @property resultPath The temporary path where all the generated tests and their Jacoco report are saved. * @property buildPath All the directories where the compiled code of the project under test is saved. - * @property requestManager The `RequestManager` instance used for making LLM requests. + * @property chatSessionManager A ChatSession manager which holds chat history and manages requests to LLM. * @property testsAssembler The `TestsAssembler` instance used for assembling generated tests. * @property testCompiler The `TestCompiler` instance used for compiling tests. * @property testStorage The `TestsPersistentStorage` instance used for storing generated tests. * @property testsPresenter The `TestsPresenter` instance used for presenting generated tests. - * @property indicator The `CustomProgressIndicator` instance used for tracking progress. * @property requestsCountThreshold The threshold for the maximum number of requests in the feedback cycle. - * @property errorMonitor The `ErrorMonitor` instance used for monitoring errors. */ class LLMWithFeedbackCycle( private val report: Report, @@ -60,226 +44,146 @@ class LLMWithFeedbackCycle( private val promptSizeReductionStrategy: PromptSizeReductionStrategy, // filename in which the test suite is saved in the result path private val testSuiteFilename: String, - private val packageName: String, // temp path where all the generated tests and their jacoco report are saved private val resultPath: String, // all the directories where the compiled code of the project under test is saved. This path will be used as a classpath to run each test case private val buildPath: String, - private val requestManager: RequestManager, + private val chatSessionManager: ChatSessionManager, private val testsAssembler: TestsAssembler, private val testCompiler: TestCompiler, private val testStorage: TestsPersistentStorage, private val testsPresenter: TestsPresenter, - private val indicator: CustomProgressIndicator, private val requestsCountThreshold: Int, - private val errorMonitor: ErrorMonitor = DefaultErrorMonitor(), ) { private val log = KotlinLogging.logger { this::class.java } - private lateinit var generatedTestSuite: TestSuiteGeneratedByLLM - - fun run(onWarningCallback: ((TestSparkError) -> Unit)? = null): Result { - var requestsCount = 0 - var generatedTestsArePassing = false - var nextPromptMessage = initialPromptMessage - - val compilableTestCases: MutableSet = mutableSetOf() - - // collect imports from all responses - val imports: MutableSet = mutableSetOf() - - while (!generatedTestsArePassing) { - requestsCount++ - - log.info { "Iteration #$requestsCount of feedback cycle" } - - // Process stopped checking - if (indicator.isCanceled()) throw ProcessCancelledException(module = TestSparkModule.Llm()) - - if (isLastIteration(requestsCount) && compilableTestCases.isEmpty()) { - // record a report with parsable yet potentially - // non-compilable test cases stored in - // the generated test suite - // TODO: ensure generatedTestSuite is always non-null here - if (::generatedTestSuite.isInitialized) { - recordReport(report, generatedTestSuite.testCases) - } - break - } - // clearing test assembler's collected text on the previous attempts - testsAssembler.clear() - val response: Result = - requestManager.request( - language = language, - prompt = nextPromptMessage, - indicator = indicator, - packageName = packageName, - testsAssembler = testsAssembler, - isUserFeedback = false, - errorMonitor, - ) - - // Process stopped checking - if (indicator.isCanceled()) throw ProcessCancelledException(module = TestSparkModule.Llm()) - - when (response) { - is Result.Success -> { - log.info { "Test suite generated successfully: ${response.data}" } - // check that there are some test cases generated - if (response.data.testCases.isEmpty()) { - onWarningCallback?.invoke(LlmError.EmptyLlmResponse) - - nextPromptMessage = - "You have provided an empty answer! Please answer my previous question with the same formats." + fun run(): Flow> = + flow { + var iteration = 0 + var nextPromptMessage = initialPromptMessage + val generatedTestSuites: MutableList = mutableListOf() + + while (iteration < requestsCountThreshold) { + iteration++ + log.info { "Iteration #$iteration of feedback cycle" } + + // ensure the testsAssembler is empty before each iteration + testsAssembler.clear() + + val chunks: Flow> = + chatSessionManager.request( + prompt = nextPromptMessage, + isUserFeedback = false, + ) + val testSuiteResult: Result = chunks.collectChunks(testsAssembler) + + when (testSuiteResult) { + is Result.Success -> log.info { "Test suite generated successfully: ${testSuiteResult.data}" } + + is Result.Failure -> { + log.info { "Cannot parse a test suite from the LLM response. LLM response: '$testSuiteResult'" } + emit(testSuiteResult) + nextPromptMessage = generatePromptMessage(testSuiteResult.error) ?: break + + /** + * The current attempt does not count as a failure since it was rejected due to the prompt size + * exceeding the threshold + */ + if (testSuiteResult.error is LlmError.PromptTooLong) iteration-- continue } } - is Result.Failure -> { - when (response.error) { - is LlmError.EmptyLlmResponse -> { - nextPromptMessage = - "You have provided an empty answer! Please, answer my previous question with the same formats" - continue - } - - is LlmError.PromptTooLong -> { - if (promptSizeReductionStrategy.isReductionPossible()) { - nextPromptMessage = promptSizeReductionStrategy.reduceSizeAndGeneratePrompt() - /** - * The current attempt does not count as a failure - * since it was rejected due to the prompt size - * exceeding the threshold - */ - requestsCount-- - continue - } else { - return Result.Failure(error = LlmError.PromptTooLong) - } - } - - is LlmError.TestSuiteParsingError -> { - onWarningCallback?.invoke(LlmError.TestSuiteParsingError) - log.info { "Cannot parse a test suite from the LLM response. LLM response: '$response'" } - - nextPromptMessage = "The provided code is not parsable. Please, generate the correct code" - continue - } + val testSuite = testSuiteResult.data + generatedTestSuites.add(testSuite) + compileTestCases(testSuite) - else -> return response - } + if (testSuite.testCases.any { it.isCompilable.not() }) { + log.info { "Non-compilable test suite: \n${testsPresenter.representTestSuite(testSuite)}" } + emit(Result.Failure(LlmError.CompilationError)) + nextPromptMessage = generateCompilationErrorPrompt(testSuite) continue } - } - - generatedTestSuite = response.data - - // update imports list - imports.addAll(generatedTestSuite.imports) - - // Process stopped checking - if (indicator.isCanceled()) throw ProcessCancelledException(module = TestSparkModule.Llm()) - - // Save the generated TestSuite into a temp file - val generatedTestCasesPaths: MutableList = mutableListOf() - - if (isLastIteration(requestsCount)) { - generatedTestSuite.updateTestCases(compilableTestCases.toMutableList()) - } else { - for (testCaseIndex in generatedTestSuite.testCases.indices) { - val testCaseFilename = - when (language) { - SupportedLanguage.Java -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.java" - SupportedLanguage.Kotlin -> "${getClassWithTestCaseName(generatedTestSuite.testCases[testCaseIndex].name)}.kt" - } - val testCaseRepresentation = testsPresenter.representTestCase(generatedTestSuite, testCaseIndex) - - val saveFilepath = - testStorage.saveGeneratedTest( - generatedTestSuite.packageName, - testCaseRepresentation, - resultPath, - testCaseFilename, - ) + break + } - generatedTestCasesPaths.add(saveFilepath) - } + log.info { "Result is compilable" } + val resultingTestSuite = joinTestSuites(generatedTestSuites) + if (resultingTestSuite != null) { + emit(Result.Success(resultingTestSuite)) + recordReport(report, resultingTestSuite.testCases) } + } - val generatedTestSuitePath: String = + private fun compileTestCases(testSuite: TestSuiteGeneratedByLLM) { + testSuite.testCases.forEachIndexed { index, testCase -> + val testCaseName = getClassWithTestCaseName(testCase.name) + val testCaseFilename = "$testCaseName${language.extension}" + val testCaseRepresentation = testsPresenter.representTestCase(testSuite, index) + val saveFilepath = testStorage.saveGeneratedTest( - generatedTestSuite.packageName, - testsPresenter.representTestSuite(generatedTestSuite), - resultPath, - testSuiteFilename, + packageString = testSuite.packageName, + code = testCaseRepresentation, + resultPath = resultPath, + testFileName = testCaseFilename, ) + testCase.isCompilable = compileTest(saveFilepath).isSuccessful() + } + } - // check that the file creation was successful - var allFilesCreated = true - for (path in generatedTestCasesPaths) { - allFilesCreated = allFilesCreated && File(path).exists() - } - if (!(allFilesCreated && File(generatedTestSuitePath).exists())) { - // either some test case file or the test suite file was not created - return Result.Failure(error = LlmError.FailedToSaveTestFiles) + private fun compileTest(filePath: String): ExecutionResult { + if (File(filePath).exists().not()) { + throw TestSavingFailureException() + } + + return testCompiler.compileCode( + path = File(filePath).absolutePath, + projectBuildPath = buildPath, + workingDir = resultPath, + ) + } + + private fun generatePromptMessage(error: TestSparkError) = + when (error) { + is LlmError.EmptyLlmResponse -> { + "You have provided an empty answer! Please, answer my previous question with the same formats" } - // Get test cases - val testCases: MutableList = - if (!isLastIteration(requestsCount)) { - generatedTestSuite.testCases + is LlmError.PromptTooLong -> { + if (promptSizeReductionStrategy.isReductionPossible()) { + promptSizeReductionStrategy.reduceSizeAndGeneratePrompt() } else { - compilableTestCases.toMutableList() + null } - - // Compile the test file - indicator.setText("Compilation tests checking") - - val testCasesCompilationResult = - testCompiler.compileTestCases(generatedTestCasesPaths, buildPath, testCases, resultPath) - val testSuiteCompilationResult = - testCompiler.compileCode(File(generatedTestSuitePath).absolutePath, buildPath, resultPath) - - // saving the compilable test cases - compilableTestCases.addAll(testCasesCompilationResult.compilableTestCases) - - // Process stopped checking - if (indicator.isCanceled()) throw ProcessCancelledException(module = TestSparkModule.Llm()) - - if (!testCasesCompilationResult.allTestCasesCompilable && !isLastIteration(requestsCount)) { - log.info { "Non-compilable test suite: \n${testsPresenter.representTestSuite(generatedTestSuite)}" } - - onWarningCallback?.invoke(LlmError.CompilationError) - - nextPromptMessage = - """ - I cannot compile the tests that you provided. The error is: - ``` - ${testSuiteCompilationResult.executionMessage} - ``` - Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text. - """.trimIndent() - log.info { nextPromptMessage } - continue } - log.info { "Result is compilable" } - - generatedTestSuite.imports.addAll(imports) - - generatedTestsArePassing = true + is LlmError.TestSuiteParsingError -> { + "The provided code is not parsable. Please, generate the correct code" + } - recordReport(report, testCases) + else -> null } - return Result.Success( - data = - FeedbackResponse( - generatedTestSuite = generatedTestSuite, - compilableTestCases = compilableTestCases, - ), - ) + private fun generateCompilationErrorPrompt(testSuite: TestSuiteGeneratedByLLM): String { + val generatedTestSuitePath: String = + testStorage.saveGeneratedTest( + testSuite.packageName, + testsPresenter.representTestSuite(testSuite), + resultPath, + testSuiteFilename, + ) + val testSuiteCompilationResult = compileTest(generatedTestSuitePath) + val prompt = + """ + I cannot compile the tests that you provided. The error is: + ``` + ${testSuiteCompilationResult.executionMessage} + ``` + Fix this issue in the provided tests.\nGenerate public classes and public methods. Response only a code with tests between ```, do not provide any other text. + """.trimIndent() + + return prompt } /** @@ -290,12 +194,23 @@ class LLMWithFeedbackCycle( */ private fun recordReport( report: Report, - testCases: MutableList, + testCases: List, ) { - for ((index, test) in testCases.withIndex()) { + testCases.forEachIndexed { index, test -> report.testCaseList[index] = TestCase(index, test.name, test.toString(), setOf()) } } - private fun isLastIteration(requestsCount: Int): Boolean = requestsCount > requestsCountThreshold + private companion object { + fun joinTestSuites(testSuites: List): TestSuiteGeneratedByLLM? { + if (testSuites.isEmpty()) return null + return TestSuiteGeneratedByLLM( + testCases = testSuites.map { it.testCases }.flatten().toMutableList(), + imports = testSuites.map { it.imports }.flatten().toMutableSet(), + otherInfo = testSuites.joinToString(separator = "\n") { it.otherInfo }, + packageName = testSuites.last().packageName, + annotation = testSuites.last().annotation, + ) + } + } } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt index ccf740896..ea48edf71 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/Utils.kt @@ -1,10 +1,15 @@ package org.jetbrains.research.testspark.core.generation.llm +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.cancel +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.jetbrains.research.testspark.core.error.LlmError import org.jetbrains.research.testspark.core.error.Result -import org.jetbrains.research.testspark.core.error.TestSparkError -import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager -import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor +import org.jetbrains.research.testspark.core.exception.ProcessCancelledException import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.SupportedLanguage import org.jetbrains.research.testspark.core.test.TestsAssembler @@ -34,6 +39,7 @@ fun getPackageFromTestSuiteCode( ?.get(1) ?.value .orEmpty() + SupportedLanguage.Java -> javaPackagePattern .find(testSuiteCode) @@ -92,19 +98,15 @@ fun getClassWithTestCaseName(testCaseName: String): String { * * @param testCase: The test that is requested to be modified * @param task: A string representing the requested task for test modification - * @param indicator: A progress indicator object that represents the indication of the test generation progress. * * @return instance of TestSuiteGeneratedByLLM if the generated test cases are parsable, otherwise null. */ -fun executeTestCaseModificationRequest( - language: SupportedLanguage, +suspend fun executeTestCaseModificationRequest( testCase: String, task: String, - indicator: CustomProgressIndicator, - requestManager: RequestManager, + chatSessionManager: ChatSessionManager, testsAssembler: TestsAssembler, - errorMonitor: ErrorMonitor = DefaultErrorMonitor(), -): Result { +): Result { // Update Token information val prompt = buildString { @@ -115,18 +117,65 @@ fun executeTestCaseModificationRequest( append(task) } - val packageName = getPackageFromTestSuiteCode(testCase, language) - - val response = - requestManager.request( - language, - prompt, - indicator, - packageName, - testsAssembler, + return chatSessionManager + .request( + prompt = prompt, isUserFeedback = true, - errorMonitor, - ) + ).collectChunks(testsAssembler) +} + +suspend fun Flow>.collectChunks(testsAssembler: TestsAssembler): Result { + var failureResponse: Result? = null + + collect { result -> + when (result) { + is Result.Success -> testsAssembler.consume(result.data) + is Result.Failure -> { + failureResponse = result + return@collect + } + } + } - return response + val testSuite = testsAssembler.assembleTestSuite() + + return if (testSuite.isSuccess()) { + testSuite + } else { + failureResponse ?: Result.Failure(LlmError.EmptyLlmResponse) + } } + +/** + * Executes the provided [action], while constantly monitoring [indicator] cancellation in parallel. + * If the indicator is canceled, it stops the execution, cancels the coroutine, and returns null. + * + * @param indicator an instance of [CustomProgressIndicator] + * @param indicatorObservingIntervalMs specifies an interval between indicator cancellation checks + * @param action block of code to be executed + */ +fun runBlockingWithIndicatorLifecycle( + indicator: CustomProgressIndicator, + indicatorObservingIntervalMs: Long = 500, + action: suspend () -> T, +): T = + try { + runBlocking { + coroutineScope { + val indicatorObserver = + launch { + while (true) { + if (indicator.isCanceled()) { + this@coroutineScope.cancel() + } + delay(indicatorObservingIntervalMs) + } + } + val result = action() + indicatorObserver.cancel() + result + } + } + } catch (e: CancellationException) { + throw ProcessCancelledException(cause = e) + } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/HttpRequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/HttpRequestManager.kt new file mode 100644 index 000000000..4c26fa867 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/HttpRequestManager.kt @@ -0,0 +1,95 @@ +package org.jetbrains.research.testspark.core.generation.llm.network + +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.logging.DEFAULT +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logger +import io.ktor.client.plugins.logging.Logging +import io.ktor.client.request.bearerAuth +import io.ktor.client.request.preparePost +import io.ktor.client.request.setBody +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.contentType +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.readUTF8Line +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.isActive +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.Json +import org.jetbrains.research.testspark.core.data.ChatMessage +import org.jetbrains.research.testspark.core.error.HttpError +import org.jetbrains.research.testspark.core.error.Result +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmParams +import java.net.HttpURLConnection + +class HttpRequestManager( + private val llmProvider: LlmProvider, + private val client: HttpClient = DEFAULT_HTTP_CLIENT, + private val json: Json = DEFAULT_JSON, +) : RequestManager { + override suspend fun sendRequest( + params: LlmParams, + chatHistory: List, + ): Flow> = + flow { + client + .preparePost(llmProvider.url(params)) { + contentType(ContentType.Application.Json) + if (llmProvider.supportsBearerAuth) bearerAuth(params.token) + setBody(llmProvider.constructJsonBody(json, params, chatHistory)) + }.execute { httpResponse -> + val responseIsSuccessful = httpResponse.status.value == HttpURLConnection.HTTP_OK + if (responseIsSuccessful) { + val channel = httpResponse.body() + while (currentCoroutineContext().isActive && !channel.isClosedForRead) { + val chunk = channel.readUTF8Line() ?: continue + val response = processChunk(chunk) ?: continue + emit(Result.Success(response)) + } + } else { + var error = llmProvider.mapHttpStatusCodeToError(httpResponse.status.value) + if (error is HttpError) error = error.copy(message = httpResponse.status.description) + emit(Result.Failure(error)) + } + } + } + + fun processChunk(chunk: String): String? { + if (chunk.startsWith(STREAMING_PREFIX).not()) return null + return llmProvider.decodeResponse(json, chunk.removePrefix(STREAMING_PREFIX)).extractContent() + } + + private companion object { + const val STREAMING_PREFIX = "data:" + + val DEFAULT_HTTP_CLIENT = + HttpClient(CIO) { + install(Logging) { + logger = Logger.DEFAULT + level = LogLevel.BODY + sanitizeHeader { header -> header == HttpHeaders.Authorization } + } + install(HttpTimeout) { + requestTimeoutMillis = 90_000 + connectTimeoutMillis = 15_000 + socketTimeoutMillis = 60_000 + } + } + + val DEFAULT_JSON = + Json { + ignoreUnknownKeys = true + encodeDefaults = true + + // this api is not error-prone; it just means it can be changed or removed in the future + @OptIn(ExperimentalSerializationApi::class) + explicitNulls = false + } + } +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/LlmProvider.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/LlmProvider.kt new file mode 100644 index 000000000..10fc71e61 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/LlmProvider.kt @@ -0,0 +1,72 @@ +package org.jetbrains.research.testspark.core.generation.llm.network + +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import org.jetbrains.research.testspark.core.data.ChatMessage +import org.jetbrains.research.testspark.core.data.LlmModuleType +import org.jetbrains.research.testspark.core.data.TestSparkModule +import org.jetbrains.research.testspark.core.error.HttpError +import org.jetbrains.research.testspark.core.error.LlmError +import org.jetbrains.research.testspark.core.error.TestSparkError +import org.jetbrains.research.testspark.core.generation.llm.network.model.GeminiResponse +import org.jetbrains.research.testspark.core.generation.llm.network.model.HuggingFaceResponse +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmParams +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmResponse +import org.jetbrains.research.testspark.core.generation.llm.network.model.OpenAIResponse +import org.jetbrains.research.testspark.core.generation.llm.network.model.constructGeminiRequestBody +import org.jetbrains.research.testspark.core.generation.llm.network.model.constructHuggingFaceRequestBody +import org.jetbrains.research.testspark.core.generation.llm.network.model.constructOpenAiRequestBody +import java.net.HttpURLConnection + +enum class LlmProvider( + val url: (LlmParams) -> String, + val supportsBearerAuth: Boolean, + val constructJsonBody: Json.(LlmParams, List) -> String, + val decodeResponse: Json.(rawTextResponse: String) -> LlmResponse, + val mapHttpStatusCodeToError: (Int) -> TestSparkError, +) { + OpenAI( + url = { "https://api.openai.com/v1/chat/completions" }, + supportsBearerAuth = true, + constructJsonBody = { params, chatHistory -> + encodeToString(constructOpenAiRequestBody(params, chatHistory)) + }, + decodeResponse = { decodeFromString(it) }, + mapHttpStatusCodeToError = { httpCode -> + when (httpCode) { + HttpURLConnection.HTTP_BAD_REQUEST -> LlmError.PromptTooLong + else -> HttpError(httpCode = httpCode, module = TestSparkModule.Llm(LlmModuleType.OpenAi)) + } + }, + ), + + Gemini( + url = { + val baseUrl = "https://generativelanguage.googleapis.com/v1beta/models/" + "$baseUrl${it.model}:streamGenerateContent?alt=sse&key=${it.token}" + }, + supportsBearerAuth = false, + constructJsonBody = { params, chatHistory -> + encodeToString(constructGeminiRequestBody(params, chatHistory)) + }, + decodeResponse = { decodeFromString(it) }, + mapHttpStatusCodeToError = { httpCode -> + when (httpCode) { + HttpURLConnection.HTTP_INTERNAL_ERROR -> LlmError.PromptTooLong + else -> HttpError(httpCode = httpCode, module = TestSparkModule.Llm(LlmModuleType.Gemini)) + } + }, + ), + + Llama( + url = { "https://api-inference.huggingface.co/models/meta-llama/${it.model}" }, + supportsBearerAuth = true, + constructJsonBody = { params, chatHistory -> + encodeToString(constructHuggingFaceRequestBody(params, chatHistory)) + }, + decodeResponse = { decodeFromString(it) }, + mapHttpStatusCodeToError = { httpCode -> + HttpError(httpCode = httpCode, module = TestSparkModule.Llm(LlmModuleType.HuggingFace)) + }, + ), +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt index 381642247..b0d9a54c6 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/RequestManager.kt @@ -1,95 +1,13 @@ package org.jetbrains.research.testspark.core.generation.llm.network -import io.github.oshai.kotlinlogging.KotlinLogging -import org.jetbrains.research.testspark.core.data.ChatAssistantMessage +import kotlinx.coroutines.flow.Flow import org.jetbrains.research.testspark.core.data.ChatMessage -import org.jetbrains.research.testspark.core.data.ChatUserMessage -import org.jetbrains.research.testspark.core.error.LlmError import org.jetbrains.research.testspark.core.error.Result -import org.jetbrains.research.testspark.core.error.TestSparkError -import org.jetbrains.research.testspark.core.monitor.DefaultErrorMonitor -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.SupportedLanguage -import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmParams -abstract class RequestManager( - var token: String, - val llmModel: String, -) { - val chatHistory = mutableListOf() - - protected val log = KotlinLogging.logger {} - - /** - * Sends a request to LLM with the given prompt and returns the generated TestSuite. - * - * @param prompt the prompt to send to LLM - * @param indicator the progress indicator to show progress during the request - * @param packageName the name of the package for the generated TestSuite - * @param isUserFeedback indicates if this request is a test generation request or a user feedback - * @return the generated TestSuite, or null and prompt message - */ - open fun request( - language: SupportedLanguage, - prompt: String, - indicator: CustomProgressIndicator, - packageName: String, - testsAssembler: TestsAssembler, - isUserFeedback: Boolean = false, - errorMonitor: ErrorMonitor = DefaultErrorMonitor(), - ): Result { - // save the prompt in chat history - chatHistory.add(ChatUserMessage(prompt)) - - // Send Request to LLM - log.info { "Sending Request..." } - - val sendResult = send(prompt, indicator, testsAssembler, errorMonitor) - - if (sendResult is Result.Failure) return sendResult - - // we remove the user request because we don't store user's requests in chat history - if (isUserFeedback) { - chatHistory.removeLast() - } - - return processResponse(testsAssembler, packageName, language, isUserFeedback) - } - - open fun processResponse( - testsAssembler: TestsAssembler, - packageName: String, - language: SupportedLanguage, - isUserFeedback: Boolean, - ): Result { - // save the full response in the chat history - val response = testsAssembler.getContent() - - log.info { "The full response: \n $response" } - if (!isUserFeedback) { - chatHistory.add(ChatAssistantMessage(response)) - } - - // check if the response is empty - if (response.isEmpty() || response.isBlank()) { - return Result.Failure(error = LlmError.EmptyLlmResponse) - } - - val testSuiteGeneratedByLLM = testsAssembler.assembleTestSuite() - - return if (testSuiteGeneratedByLLM == null) { - Result.Failure(error = LlmError.TestSuiteParsingError) - } else { - Result.Success(data = testSuiteGeneratedByLLM.reformat()) - } - } - - abstract fun send( - prompt: String, - indicator: CustomProgressIndicator, - testsAssembler: TestsAssembler, - errorMonitor: ErrorMonitor = DefaultErrorMonitor(), - ): Result +interface RequestManager { + suspend fun sendRequest( + params: LlmParams, + chatHistory: List, + ): Flow> } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/GeminiModel.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/GeminiModel.kt new file mode 100644 index 000000000..787e51ece --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/GeminiModel.kt @@ -0,0 +1,90 @@ +package org.jetbrains.research.testspark.core.generation.llm.network.model + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.jetbrains.research.testspark.core.data.ChatMessage + +@Serializable +data class GeminiRequest( + val contents: List, + val generationConfig: GeminiGenerationConfig?, + @SerialName("system_instruction") + val systemInstruction: GeminiSystemInstruction?, +) : LlmRequest + +@Serializable +data class GeminiResponse( + val candidates: List, +) : LlmResponse { + override fun extractContent(): String = + candidates + .first() + .content.parts + .first() + .text +} + +@Serializable +data class GeminiRequestContents( + val role: String, + val parts: List, +) + +@Serializable +data class GeminiTextObject( + val text: String, +) + +@Serializable +data class GeminiSystemInstruction( + val parts: List, +) + +@Serializable +data class GeminiGenerationConfig( + val temperature: Float?, + @SerialName("top_p") + val topP: Float?, +) + +@Serializable +data class GeminiCandidate( + val content: GeminiReplyContent, +) + +@Serializable +data class GeminiReplyContent( + val parts: List, + val role: String?, +) + +@Serializable +data class GeminiReplyPart( + val text: String, +) + +internal fun constructGeminiRequestBody( + params: LlmParams, + messages: List, +) = GeminiRequest( + contents = + messages.map { + GeminiRequestContents( + role = + when (it.role) { + ChatMessage.ChatRole.User -> "user" + ChatMessage.ChatRole.Assistant -> "model" + }, + parts = listOf(GeminiTextObject(it.content)), + ) + }, + systemInstruction = + params.systemPrompt?.let { + GeminiSystemInstruction(parts = listOf(GeminiTextObject(it))) + }, + generationConfig = + GeminiGenerationConfig( + temperature = params.temperature, + topP = params.topProbability, + ), +) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/HuggingFaceModel.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/HuggingFaceModel.kt new file mode 100644 index 000000000..68df06549 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/HuggingFaceModel.kt @@ -0,0 +1,56 @@ +package org.jetbrains.research.testspark.core.generation.llm.network.model + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.jetbrains.research.testspark.core.data.ChatMessage + +@Serializable +data class HuggingFaceRequest( + val inputs: String, + val parameters: HuggingFaceParameters, + val stream: Boolean = true, +) : LlmRequest + +@Serializable +data class HuggingFaceResponse( + val token: HuggingFaceToken, +) : LlmResponse { + override fun extractContent(): String = token.text +} + +@Serializable +data class HuggingFaceParameters( + val temperature: Float?, + @SerialName("top_p") + val topProbability: Float?, + @SerialName("return_full_text") + val appendPromptToResponse: Boolean = false, + @SerialName("min_length") + val minLength: Int = 4096, + @SerialName("max_length") + val maxLength: Int = 8192, + @SerialName("max_new_tokens") + val maxNewTokens: Int = 250, + @SerialName("max_time") + val maxTime: Float = 120.0F, +) + +@Serializable +data class HuggingFaceToken( + val text: String, +) + +internal fun constructHuggingFaceRequestBody( + params: LlmParams, + messages: List, +): HuggingFaceRequest { + val systemPrompt = params.systemPrompt?.let { "$it\n" } ?: "" + return HuggingFaceRequest( + inputs = systemPrompt + messages.joinToString(separator = "\n") { it.content }, + parameters = + HuggingFaceParameters( + temperature = params.temperature, + topProbability = params.topProbability, + ), + ) +} diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/LlmCommonModel.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/LlmCommonModel.kt new file mode 100644 index 000000000..df493c4f2 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/LlmCommonModel.kt @@ -0,0 +1,15 @@ +package org.jetbrains.research.testspark.core.generation.llm.network.model + +interface LlmResponse { + fun extractContent(): String +} + +interface LlmRequest + +data class LlmParams( + val model: String, + val token: String, + val systemPrompt: String? = null, + val temperature: Float? = null, + val topProbability: Float? = null, +) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/OpenAiModel.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/OpenAiModel.kt new file mode 100644 index 000000000..bf66c0500 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/generation/llm/network/model/OpenAiModel.kt @@ -0,0 +1,60 @@ +package org.jetbrains.research.testspark.core.generation.llm.network.model + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.jetbrains.research.testspark.core.data.ChatMessage + +@Serializable +data class OpenAIRequest( + val model: String, + val stream: Boolean = true, + val messages: List, + val temperature: Float?, + @SerialName("top_p") + val topProbability: Float? = null, +) : LlmRequest + +@Serializable +data class OpenAIResponse( + val choices: List, +) : LlmResponse { + override fun extractContent(): String = choices.first().delta.content +} + +@Serializable +data class OpenAIChatMessage( + val role: String, + val content: String, +) + +@Serializable +data class OpenAIChoice( + val index: Int, + val delta: Delta, + @SerialName("finish_reason") + val finishedReason: String, +) + +@Serializable +data class Delta( + val role: String?, + val content: String, +) + +internal fun constructOpenAiRequestBody( + params: LlmParams, + messages: List, +) = OpenAIRequest( + model = params.model, + messages = + messages.map { message -> + val role = + when (message.role) { + ChatMessage.ChatRole.User -> "user" + ChatMessage.ChatRole.Assistant -> "assistant" + } + OpenAIChatMessage(role, message.content) + }, + temperature = params.temperature, + topProbability = params.topProbability, +) diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/JUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/JUnitTestSuiteParser.kt index ff0a33038..c1807cc81 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/JUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/JUnitTestSuiteParser.kt @@ -36,5 +36,5 @@ interface JUnitTestSuiteParser { * @param rawText The raw text provided by the LLM that contains the generated test cases. * @return A GeneratedTestSuite instance containing the extracted test cases. */ - fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? + fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt index f13d40c7f..f14e10d41 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/SupportedLanguage.kt @@ -5,7 +5,8 @@ package org.jetbrains.research.testspark.core.test */ enum class SupportedLanguage( val languageId: String, + val extension: String, ) { - Java("JAVA"), - Kotlin("kotlin"), + Java(languageId = "JAVA", extension = ".java"), + Kotlin(languageId = "kotlin", extension = ".kt"), } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt index 86a2cf4f8..b13839753 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestCompiler.kt @@ -1,13 +1,7 @@ package org.jetbrains.research.testspark.core.test -import org.jetbrains.research.testspark.core.test.data.TestCaseGeneratedByLLM import org.jetbrains.research.testspark.core.utils.DataFilesUtil -data class TestCasesCompilationResult( - val allTestCasesCompilable: Boolean, - val compilableTestCases: MutableSet, -) - data class ExecutionResult( val exitCode: Int, val executionMessage: String, @@ -24,35 +18,6 @@ abstract class TestCompiler( val junitPath = junitLibPaths.joinToString(separator.toString()) val commonPath = "$junitPath${separator}$dependencyLibPath$separator" - /** - * Compiles a list of test cases and returns the compilation result. - * - * @param generatedTestCasesPaths A list of file paths where the generated test cases are located. - * @param buildPath All the directories where the compiled code of the project under test is saved. This path is used as a classpath to run each test case. - * @param testCases A mutable list of `TestCaseGeneratedByLLM` objects representing the test cases to be compiled. - * @param workingDir The path of the directory that contains package directories of the code to compile - * @return A `TestCasesCompilationResult` object containing the overall compilation success status and a set of compilable test cases. - */ - fun compileTestCases( - generatedTestCasesPaths: List, - buildPath: String, - testCases: MutableList, - workingDir: String, - ): TestCasesCompilationResult { - var allTestCasesCompilable = true - val compilableTestCases: MutableSet = mutableSetOf() - - for (index in generatedTestCasesPaths.indices) { - val compilable = compileCode(generatedTestCasesPaths[index], buildPath, workingDir).isSuccessful() - allTestCasesCompilable = allTestCasesCompilable && compilable - if (compilable) { - compilableTestCases.add(testCases[index]) - } - } - - return TestCasesCompilationResult(allTestCasesCompilable, compilableTestCases) - } - /** * Compiles the code at the specified path using the provided project build path. * diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt index 64006f117..d25f78232 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/TestsAssembler.kt @@ -1,5 +1,6 @@ package org.jetbrains.research.testspark.core.test +import org.jetbrains.research.testspark.core.error.Result import org.jetbrains.research.testspark.core.test.data.TestSuiteGeneratedByLLM abstract class TestsAssembler { @@ -34,5 +35,5 @@ abstract class TestsAssembler { * * @return A TestSuiteGeneratedByLLM object containing information about the extracted test cases. */ - abstract fun assembleTestSuite(): TestSuiteGeneratedByLLM? + abstract fun assembleTestSuite(): Result } diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt index 8c99130d1..5305a1037 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestCaseGeneratedByLLM.kt @@ -14,6 +14,7 @@ data class TestCaseGeneratedByLLM( var throwsException: String = "", var lines: MutableList = mutableListOf(), val printTestBodyStrategy: TestBodyPrinter, + var isCompilable: Boolean = true, ) { /** * Compares this object to the specified object for equality. diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt index 386bb842d..cb1da3db8 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt @@ -14,7 +14,7 @@ class JavaJUnitTestSuiteParser( private val junitVersion: JUnitVersion, private val testBodyPrinter: TestBodyPrinter, ) : JUnitTestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM { val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Java) if (packageInsideTestText.isNotBlank()) { packageName = packageInsideTestText diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt index 94dee63fb..102be4933 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt @@ -14,7 +14,7 @@ class KotlinJUnitTestSuiteParser( private val junitVersion: JUnitVersion, private val testBodyPrinter: TestBodyPrinter, ) : JUnitTestSuiteParser { - override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM? { + override fun parseTestSuite(rawText: String): TestSuiteGeneratedByLLM { val packageInsideTestText = getPackageFromTestSuiteCode(rawText, SupportedLanguage.Kotlin) if (packageInsideTestText.isNotBlank()) { packageName = packageInsideTestText diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt index a1b96f967..03b5125d0 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt @@ -25,72 +25,74 @@ class JUnitTestSuiteParserStrategy { packageName: String, testNamePattern: String, printTestBodyStrategy: TestBodyPrinter, - ): TestSuiteGeneratedByLLM? { - if (rawText.isBlank()) { - return null - } - - try { - val rawCode = if (rawText.contains("```")) rawText.split("```")[1] else rawText - - // save imports - val imports = - importPattern - .findAll(rawCode) - .map { it.groupValues[0] } - .toMutableSet() - - // save ExtendWith or RunWith annotation if present - val runWithAnnotation: String = JUnitVersion.JUnit4.runWithAnnotationMeta.extract(rawCode) ?: "" - val annotation = JUnitVersion.JUnit5.runWithAnnotationMeta.extract(rawCode) ?: runWithAnnotation - - val testSet: MutableList = rawCode.split("@Test").toMutableList() - - // save annotations and pre-set methods - val otherInfo: String = - run { - val otherInfoList = testSet.removeAt(0).split("{").toMutableList() - otherInfoList.removeFirst() - val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" - otherInfo.ifBlank { "" } - } + ): TestSuiteGeneratedByLLM { + require(rawText.isNotBlank()) { "The raw text is blank" } + + val rawCode = + if (rawText.contains("```")) { + rawText + .replace(oldValue = "```", newValue = "") + .trimIndent() + .trim() + } else { + rawText + } - // Save the main test cases - val testCases: MutableList = mutableListOf() - val testCaseParser = JUnitTestCaseParser() + // save imports + val imports = + importPattern + .findAll(rawCode) + .map { it.groupValues[0] } + .toMutableSet() + + // save ExtendWith or RunWith annotation if present + val runWithAnnotation: String = JUnitVersion.JUnit4.runWithAnnotationMeta.extract(rawCode) ?: "" + val annotation = JUnitVersion.JUnit5.runWithAnnotationMeta.extract(rawCode) ?: runWithAnnotation + + val testSet: MutableList = rawCode.split("@Test").toMutableList() + + // save annotations and pre-set methods + val otherInfo: String = + run { + val otherInfoList = testSet.removeAt(0).split("{").toMutableList() + otherInfoList.removeFirst() + val otherInfo = otherInfoList.joinToString("{").trimEnd() + "\n\n" + otherInfo.ifBlank { "" } + } - testSet.forEach ca@{ - val rawTest = "@Test$it" + // Save the main test cases + val testCases: MutableList = mutableListOf() + val testCaseParser = JUnitTestCaseParser() - val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) - val result: Result = - testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern, printTestBodyStrategy) + testSet.forEach ca@{ + val rawTest = "@Test$it" - if (result.isFailure()) { - return@ca - } + val isLastTestCaseInTestSuite = (testCases.size == testSet.size - 1) + val result: Result = + testCaseParser.parse(rawTest, isLastTestCaseInTestSuite, testNamePattern, printTestBodyStrategy) - val currentTest = (result as Result.Success).data + if (result.isFailure()) { + return@ca + } - // TODO: make logging work - // log.info("New test case: $currentTest") + val currentTest = (result as Result.Success).data - testCases.add(currentTest) - } + // TODO: make logging work + // log.info("New test case: $currentTest") - val testSuite = - TestSuiteGeneratedByLLM( - imports = imports, - packageName = packageName, - annotation = annotation, - otherInfo = otherInfo, - testCases = testCases, - ) - - return testSuite - } catch (e: Exception) { - return null + testCases.add(currentTest) } + + val testSuite = + TestSuiteGeneratedByLLM( + imports = imports, + packageName = packageName, + annotation = annotation, + otherInfo = otherInfo, + testCases = testCases, + ) + + return testSuite } } @@ -100,7 +102,7 @@ class JUnitTestSuiteParserStrategy { isLastTestCaseInTestSuite: Boolean, testNamePattern: String, printTestBodyStrategy: TestBodyPrinter, - ): Result { + ): Result { var expectedException = "" var throwsException = "" val testLines: MutableList = mutableListOf() diff --git a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt index 3a65ca112..b5257c289 100644 --- a/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt +++ b/core/src/test/kotlin/org/jetbrains/research/testspark/core/test/parsers/kotlin/KotlinJUnitTestSuiteParserTest.kt @@ -260,7 +260,33 @@ class KotlinJUnitTestSuiteParserTest { assertNotNull(testSuite1) assertNotNull(testSuite2) - assertEquals("org.pkg1", testSuite1!!.packageName) - assertEquals("org.pkg2", testSuite2!!.packageName) + assertEquals("org.pkg1", testSuite1.packageName) + assertEquals("org.pkg2", testSuite2.packageName) + } + + @Test + fun testParseWithRandomBackticksAtBeginning() { + val text = + """ + ``` + ```kotlin + import org.junit.jupiter.api.Test + + class RandomBackticksTestClass { + @Test + fun testWithRandomBackticks() { + // Test case implementation + } + } + ``` + """.trimIndent() + + val testBodyPrinter = KotlinTestBodyPrinter() + val parser = KotlinJUnitTestSuiteParser("org.example", JUnitVersion.JUnit5, testBodyPrinter) + val testSuite: TestSuiteGeneratedByLLM? = parser.parseTestSuite(text) + + assertNotNull(testSuite) + assertEquals(1, testSuite!!.testCases.size) + assertEquals("testWithRandomBackticks", testSuite.testCases[0].name) } } diff --git a/gradle.properties b/gradle.properties index 8d71579f8..d3cca3e09 100644 --- a/gradle.properties +++ b/gradle.properties @@ -33,4 +33,6 @@ gradleVersion = 8.10.2 # suppress inspection "UnusedProperty" kotlin.stdlib.default.dependency = false -jvmToolchainVersion = 17 \ No newline at end of file +jvmToolchainVersion = 17 + +kotlin.daemon.jvmargs=-Xmx2g \ No newline at end of file diff --git a/src/hasGrazieAccess/kotlin/org/jetbrains/research/grazie/Request.kt b/src/hasGrazieAccess/kotlin/org/jetbrains/research/grazie/Request.kt index 6100e580f..f75485ae6 100644 --- a/src/hasGrazieAccess/kotlin/org/jetbrains/research/grazie/Request.kt +++ b/src/hasGrazieAccess/kotlin/org/jetbrains/research/grazie/Request.kt @@ -1,29 +1,16 @@ package org.jetbrains.research.grazie -import kotlinx.coroutines.flow.catch -import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.flow.Flow import org.jetbrains.research.testSpark.grazie.TestGeneration -import org.jetbrains.research.testspark.core.test.TestsAssembler import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequest class Request : GrazieRequest { - override fun request( + override suspend fun request( token: String, messages: List>, profile: String, - testsAssembler: TestsAssembler, - ): String { + ): Flow { val generation = TestGeneration(token) - var errorMessage = "" - runBlocking { - generation - .generate(messages, profile) - .catch { - errorMessage = it.message.toString() - }.collect { - testsAssembler.consume(it) - } - } - return errorMessage + return generation.generate(messages, profile) } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt b/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt index 12b900156..c0ec4a99c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/data/UIContext.kt @@ -2,13 +2,13 @@ package org.jetbrains.research.testspark.data import org.jetbrains.research.testspark.actions.controllers.IndicatorController import org.jetbrains.research.testspark.core.data.TestGenerationData -import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager +import org.jetbrains.research.testspark.core.generation.llm.ChatSessionManager import org.jetbrains.research.testspark.core.monitor.ErrorMonitor data class UIContext( val projectContext: ProjectContext, val testGenerationOutput: TestGenerationData, - var requestManager: RequestManager? = null, + var chatSessionManager: ChatSessionManager, val indicatorController: IndicatorController, val errorMonitor: ErrorMonitor, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TestCasePanelBuilder.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TestCasePanelBuilder.kt index e1a3f7dde..47737d315 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TestCasePanelBuilder.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/generatedTests/TestCasePanelBuilder.kt @@ -82,7 +82,6 @@ import javax.swing.ScrollPaneConstants import javax.swing.SwingUtilities import javax.swing.border.Border import javax.swing.border.MatteBorder -import kotlin.collections.HashMap class TestCasePanelBuilder( private val project: Project, @@ -530,14 +529,13 @@ class TestCasePanelBuilder( val testModificationResult = LLMHelper.testModificationRequest( - language, - initialCodes[currentRequestNumber - 1], - requestComboBox.editor.item.toString(), - ijIndicator, - uiContext.requestManager!!, - project, - uiContext.testGenerationOutput, - uiContext.errorMonitor, + language = language, + testCase = initialCodes[currentRequestNumber - 1], + task = requestComboBox.editor.item.toString(), + indicator = ijIndicator, + chatSessionManager = uiContext.chatSessionManager, + project = project, + testGenerationOutput = uiContext.testGenerationOutput, ) when (testModificationResult) { @@ -551,7 +549,10 @@ class TestCasePanelBuilder( is Result.Success -> { if (testModificationResult.data.isEmpty()) { - LLMErrorManager().warningProcess(LLMMessagesBundle.get("modifyWithLLMError"), project) + LLMErrorManager().warningProcess( + LLMMessagesBundle.get("modifyWithLLMError"), + project, + ) } else { testModificationResult.data.setTestFileName( getClassWithTestCaseName(testCase.testName), diff --git a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt index 3b425de57..101d54a03 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/helpers/LLMHelper.kt @@ -8,11 +8,10 @@ import com.intellij.util.io.HttpRequests import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.llm.LLMSettingsBundle import org.jetbrains.research.testspark.core.data.TestGenerationData -import org.jetbrains.research.testspark.core.error.LlmError import org.jetbrains.research.testspark.core.error.Result -import org.jetbrains.research.testspark.core.error.TestSparkError +import org.jetbrains.research.testspark.core.generation.llm.ChatSessionManager import org.jetbrains.research.testspark.core.generation.llm.executeTestCaseModificationRequest -import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager +import org.jetbrains.research.testspark.core.generation.llm.runBlockingWithIndicatorLifecycle import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.JUnitTestSuiteParser @@ -273,16 +272,10 @@ object LLMHelper { testCase: String, task: String, indicator: CustomProgressIndicator, - requestManager: RequestManager, + chatSessionManager: ChatSessionManager, project: Project, testGenerationOutput: TestGenerationData, - errorMonitor: ErrorMonitor, - ): Result { - // Update Token information - if (!updateToken(requestManager, project, errorMonitor)) { - return Result.Failure(error = LlmError.UnsetTokenError) - } - + ): Result { val jUnitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion val testBodyPrinter = TestBodyPrinter.create(language) val testSuiteParser = @@ -301,32 +294,17 @@ object LLMHelper { ) val testSuite = - executeTestCaseModificationRequest( - language, - testCase, - task, - indicator, - requestManager, - testsAssembler, - errorMonitor, - ) + runBlockingWithIndicatorLifecycle(indicator) { + executeTestCaseModificationRequest( + testCase, + task, + chatSessionManager, + testsAssembler, + ) + } return testSuite } - /** - * Updates token based on the last entries of settings and check if the token is valid - * - * @return True if the token is set, false otherwise. - */ - private fun updateToken( - requestManager: RequestManager, - project: Project, - errorMonitor: ErrorMonitor, - ): Boolean { - requestManager.token = LlmSettingsArguments(project).getToken() - return LlmSettingsArguments(project).isTokenSet() - } - /** * Retrieves a list of available models from the OpenAI API. * @@ -444,5 +422,7 @@ object LLMHelper { * * @return an array of string representing the available HuggingFace models */ - private fun getHuggingFaceModels(): Array = arrayOf("Meta-Llama-3-8B-Instruct", "Meta-Llama-3-70B-Instruct") + private fun getHuggingFaceModels(): Array = LlamaModels + + val LlamaModels = arrayOf("Meta-Llama-3-8B-Instruct", "Meta-Llama-3-70B-Instruct") } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/settings/plugin/PluginSettingsConfigurable.kt b/src/main/kotlin/org/jetbrains/research/testspark/settings/plugin/PluginSettingsConfigurable.kt index 8271929f8..a73a6ab69 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/settings/plugin/PluginSettingsConfigurable.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/settings/plugin/PluginSettingsConfigurable.kt @@ -1,6 +1,5 @@ package org.jetbrains.research.testspark.settings.plugin -import com.intellij.openapi.components.service import com.intellij.openapi.project.Project import org.jetbrains.research.testspark.services.PluginSettingsService import org.jetbrains.research.testspark.settings.template.SettingsConfigurable @@ -30,7 +29,7 @@ class PluginSettingsConfigurable( * Sets the stored state values to the corresponding UI components. This method is called immediately after `createComponent` method. */ override fun reset() { - val settingsState: PluginSettingsState = project.service().state + val settingsState: PluginSettingsState = project.getService(PluginSettingsService::class.java).state settingsComponent!!.showCoverageCheckboxSelected = settingsState.showCoverageCheckboxSelected settingsComponent!!.buildPath = settingsState.buildPath settingsComponent!!.colorRed = settingsState.colorRed @@ -45,7 +44,7 @@ class PluginSettingsConfigurable( * @return whether any setting has been modified */ override fun isModified(): Boolean { - val settingsState: PluginSettingsState = project.service().state + val settingsState: PluginSettingsState = project.getService(PluginSettingsService::class.java).state var modified: Boolean = settingsComponent!!.buildPath != settingsState.buildPath modified = modified or (settingsComponent!!.showCoverageCheckboxSelected != settingsState.showCoverageCheckboxSelected) modified = modified or (settingsComponent!!.colorRed != settingsState.colorRed) @@ -59,7 +58,7 @@ class PluginSettingsConfigurable( * Persists the modified state after a user hit Apply button. */ override fun apply() { - val settingsState: PluginSettingsState = project.service().state + val settingsState: PluginSettingsState = project.getService(PluginSettingsService::class.java).state settingsState.showCoverageCheckboxSelected = settingsComponent!!.showCoverageCheckboxSelected settingsState.colorRed = settingsComponent!!.colorRed settingsState.colorGreen = settingsComponent!!.colorGreen diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/error/ErrorUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/error/ErrorUtils.kt index 273b26db1..3bca58e0d 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/error/ErrorUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/error/ErrorUtils.kt @@ -14,6 +14,7 @@ import org.jetbrains.research.testspark.core.error.LlmError import org.jetbrains.research.testspark.core.error.TestSparkError import org.jetbrains.research.testspark.core.exception.CommonException import org.jetbrains.research.testspark.core.exception.CompilerException +import org.jetbrains.research.testspark.core.exception.ProcessCancelledException import org.jetbrains.research.testspark.core.exception.TestSparkException import org.jetbrains.research.testspark.tools.error.message.commonExceptionMessage import org.jetbrains.research.testspark.tools.error.message.compilerExceptionMessage @@ -35,12 +36,16 @@ fun Project.createNotification( fun Project.createNotification( exception: TestSparkException, notificationType: NotificationType, -) = createNotification( - module = exception.module, - message = exception.displayMessage ?: PluginMessagesBundle.get("unknownErrorMessage"), - notificationType = notificationType, - logMessage = "Hello World exception $exception ${exception.displayMessage?.let { ": $it" } ?: ""}", -) +) { + // process cancellation exceptions are part of the normal execution flow + if (exception is ProcessCancelledException) return + return createNotification( + module = exception.module, + message = exception.displayMessage ?: PluginMessagesBundle.get("unknownErrorMessage"), + notificationType = notificationType, + logMessage = "$exception ${exception.displayMessage?.let { ": $it" } ?: ""}", + ) +} val TestSparkException.displayMessage: String? get() = @@ -67,7 +72,7 @@ fun Project.createNotification( logMessage: String, ) { val log = Logger.getInstance(this::class.java) - log.info("Error in $module module: $logMessage") + log.warn("Error in $module module: $logMessage") NotificationGroupManager .getInstance() diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/CompilerExceptionMessage.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/CompilerExceptionMessage.kt index c27e2d85d..df87766df 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/CompilerExceptionMessage.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/CompilerExceptionMessage.kt @@ -7,6 +7,7 @@ import org.jetbrains.research.testspark.core.exception.CompilerException import org.jetbrains.research.testspark.core.exception.JavaCompilerNotFoundException import org.jetbrains.research.testspark.core.exception.JavaSDKMissingException import org.jetbrains.research.testspark.core.exception.KotlinCompilerNotFoundException +import org.jetbrains.research.testspark.core.exception.TestSavingFailureException val CompilerException.compilerExceptionMessage: String? get() = @@ -18,4 +19,5 @@ val CompilerException.compilerExceptionMessage: String? PluginMessagesBundle.get("compilerNotFoundErrorMessage").format("Java", javaHomeDirectoryPath) is KotlinCompilerNotFoundException -> PluginMessagesBundle.get("compilerNotFoundErrorMessage").format("Kotlin", kotlinSdkHomeDirectory) + is TestSavingFailureException -> LLMMessagesBundle.get("savingTestFileIssue") } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/LlmErrorMessage.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/LlmErrorMessage.kt index 910381f5e..c2ea9a354 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/LlmErrorMessage.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/error/message/LlmErrorMessage.kt @@ -11,7 +11,6 @@ val LlmError.llmErrorDisplayMessage: String? is LlmError.PromptTooLong -> LLMMessagesBundle.get("tooLongPromptRequest") is LlmError.GrazieNotAvailable -> LLMMessagesBundle.get("grazieError") is LlmError.NoCompilableTestCasesGenerated -> LLMMessagesBundle.get("invalidLLMResult") - is LlmError.FailedToSaveTestFiles -> LLMMessagesBundle.get("savingTestFileIssue") is LlmError.CompilationError -> LLMMessagesBundle.get("compilationError") is LlmError.EmptyLlmResponse -> LLMMessagesBundle.get("emptyResponse") is LlmError.TestSuiteParsingError -> LLMMessagesBundle.get("emptyResponse") diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt index 4bf114de4..df7940043 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/evosuite/generation/EvoSuiteProcessManager.kt @@ -33,7 +33,7 @@ import org.jetbrains.research.testspark.tools.TestsExecutionResultManager import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.evosuite.EvoSuiteSettingsArguments import org.jetbrains.research.testspark.tools.evosuite.error.EvoSuiteErrorManager -import org.jetbrains.research.testspark.tools.llm.generation.StandardRequestManagerFactory +import org.jetbrains.research.testspark.tools.llm.generation.ChatSessionManagerFactory import org.jetbrains.research.testspark.tools.template.generation.ProcessManager import java.io.FileReader import java.nio.charset.Charset @@ -252,7 +252,7 @@ class EvoSuiteProcessManager( return UIContext( projectContext, generatedTestsData, - StandardRequestManagerFactory(project).getRequestManager(project), + ChatSessionManagerFactory.getChatSessionManager(project), indicatorController, errorMonitor, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/kex/generation/KexProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/kex/generation/KexProcessManager.kt index 01bd248c9..158ebdb60 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/kex/generation/KexProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/kex/generation/KexProcessManager.kt @@ -27,7 +27,7 @@ import org.jetbrains.research.testspark.tools.TestsExecutionResultManager import org.jetbrains.research.testspark.tools.ToolUtils import org.jetbrains.research.testspark.tools.kex.KexSettingsArguments import org.jetbrains.research.testspark.tools.kex.error.KexErrorManager -import org.jetbrains.research.testspark.tools.llm.generation.StandardRequestManagerFactory +import org.jetbrains.research.testspark.tools.llm.generation.ChatSessionManagerFactory import org.jetbrains.research.testspark.tools.template.generation.ProcessManager import java.io.File import java.io.IOException @@ -156,7 +156,7 @@ class KexProcessManager( return UIContext( projectContext, generatedTestsData, - StandardRequestManagerFactory(project).getRequestManager(project), + ChatSessionManagerFactory.getChatSessionManager(project), indicatorController, errorMonitor, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/ChatSessionManagerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/ChatSessionManagerFactory.kt new file mode 100644 index 000000000..7a7b4673a --- /dev/null +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/ChatSessionManagerFactory.kt @@ -0,0 +1,39 @@ +package org.jetbrains.research.testspark.tools.llm.generation + +import com.intellij.openapi.project.Project +import org.jetbrains.research.testspark.core.generation.llm.ChatSessionManager +import org.jetbrains.research.testspark.core.generation.llm.network.HttpRequestManager +import org.jetbrains.research.testspark.core.generation.llm.network.LlmProvider +import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmParams +import org.jetbrains.research.testspark.helpers.LLMHelper.LlamaModels +import org.jetbrains.research.testspark.settings.llm.LLMSettingsState.DefaultLLMSettingsState +import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments +import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager + +object ChatSessionManagerFactory { + fun getChatSessionManager(project: Project): ChatSessionManager = + ChatSessionManager( + requestManager = getRequestManager(project), + llmParams = + LlmParams( + model = LlmSettingsArguments(project).getModel(), + token = LlmSettingsArguments(project).getToken(), + ), + ) + + private fun getRequestManager(project: Project): RequestManager { + val model = LlmSettingsArguments(project).getModel() + return when (val platform = LlmSettingsArguments(project).currentLLMPlatformName()) { + DefaultLLMSettingsState.grazieName -> GrazieRequestManager() + DefaultLLMSettingsState.openAIName -> HttpRequestManager(LlmProvider.OpenAI) + DefaultLLMSettingsState.huggingFaceName -> + when { + model in LlamaModels -> HttpRequestManager(LlmProvider.Llama) + else -> throw IllegalStateException("Unsupported model: $model") + } + DefaultLLMSettingsState.geminiName -> HttpRequestManager(LlmProvider.Gemini) + else -> throw IllegalStateException("Unknown selected platform: $platform") + } + } +} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt index b797731b0..b2a8d9ba7 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt @@ -4,6 +4,8 @@ import com.intellij.openapi.diagnostic.Logger import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.JUnitVersion import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.core.error.LlmError +import org.jetbrains.research.testspark.core.error.Result import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.JUnitTestSuiteParser import org.jetbrains.research.testspark.core.test.TestsAssembler @@ -48,23 +50,31 @@ class JUnitTestsAssembler( } } - override fun assembleTestSuite(): TestSuiteGeneratedByLLM? { - val testSuite = junitTestSuiteParser.parseTestSuite(super.getContent()) + override fun assembleTestSuite(): Result = + try { + val testSuite = junitTestSuiteParser.parseTestSuite(super.getContent()) - // save RunWith - if (testSuite?.annotation?.isNotBlank() == true) { - generationData.annotation = testSuite.annotation - generationData.importsCode.add(junitVersion.runWithAnnotationMeta.import) - } else { - generationData.annotation = "" - generationData.importsCode.remove(junitVersion.runWithAnnotationMeta.import) - } + // save RunWith + if (testSuite.annotation.isNotBlank()) { + generationData.annotation = testSuite.annotation + generationData.importsCode.add(junitVersion.runWithAnnotationMeta.import) + } else { + generationData.annotation = "" + generationData.importsCode.remove(junitVersion.runWithAnnotationMeta.import) + } - // save annotations and pre-set methods - generationData.otherInfo = testSuite?.otherInfo ?: "" + // save annotations and pre-set methods + generationData.otherInfo = testSuite.otherInfo - // logging generated test cases if any - testSuite?.testCases?.forEach { testCase -> log.info("Generated test case: $testCase") } - return testSuite - } + // logging generated test cases if any + testSuite.testCases.forEach { testCase -> log.info("Generated test case: $testCase") } + + if (testSuite.testCases.isEmpty()) { + Result.Failure(error = LlmError.EmptyLlmResponse) + } else { + Result.Success(testSuite) + } + } catch (e: Exception) { + Result.Failure(LlmError.TestSuiteParsingError(cause = e)) + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt index e9cde0585..e66e15777 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/LLMProcessManager.kt @@ -10,6 +10,7 @@ import org.jetbrains.research.testspark.bundles.llm.LLMMessagesBundle import org.jetbrains.research.testspark.bundles.plugin.PluginMessagesBundle import org.jetbrains.research.testspark.core.data.TestGenerationData import org.jetbrains.research.testspark.core.data.TestSparkModule +import org.jetbrains.research.testspark.core.error.LlmError import org.jetbrains.research.testspark.core.error.Result import org.jetbrains.research.testspark.core.exception.JavaSDKMissingException import org.jetbrains.research.testspark.core.exception.ProcessCancelledException @@ -17,6 +18,7 @@ import org.jetbrains.research.testspark.core.generation.llm.LLMWithFeedbackCycle import org.jetbrains.research.testspark.core.generation.llm.getImportsCodeFromTestSuiteCode import org.jetbrains.research.testspark.core.generation.llm.getPackageFromTestSuiteCode import org.jetbrains.research.testspark.core.generation.llm.prompt.PromptSizeReductionStrategy +import org.jetbrains.research.testspark.core.generation.llm.runBlockingWithIndicatorLifecycle import org.jetbrains.research.testspark.core.monitor.ErrorMonitor import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator import org.jetbrains.research.testspark.core.test.JUnitTestSuiteParser @@ -126,8 +128,7 @@ class LLMProcessManager( val initialPromptMessage = promptManager.generatePrompt(codeType, testSamplesCode, generatedTestsData.polyDepthReducing) - // initiate a new RequestManager - val requestManager = StandardRequestManagerFactory(project).getRequestManager(project) + val chatSessionManager = ChatSessionManagerFactory.getChatSessionManager(project) // adapter for the existing prompt reduction functionality val promptSizeReductionStrategy = @@ -193,53 +194,29 @@ class LLMProcessManager( initialPromptMessage = initialPromptMessage, promptSizeReductionStrategy = promptSizeReductionStrategy, testSuiteFilename = testFileName, - packageName = packageName, resultPath = generatedTestsData.resultPath, buildPath = buildPath, - requestManager = requestManager, + chatSessionManager = chatSessionManager, testsAssembler = testsAssembler, testCompiler = testCompiler, testStorage = testProcessor, testsPresenter = testsPresenter, - indicator = indicator, requestsCountThreshold = maxRequests, - errorMonitor = errorMonitor, ) - val feedbackResponse = - llmFeedbackCycle.run { warning -> - project.createNotification(warning, NotificationType.WARNING) - } - - // Process stopped checking - if (ToolUtils.isProcessStopped(errorMonitor, indicator)) throw ProcessCancelledException(TestSparkModule.Llm()) - log.info("Feedback cycle finished execution with $feedbackResponse") - - when (feedbackResponse) { - is Result.Success -> { - log.info("Add ${feedbackResponse.data.compilableTestCases.size} compilable test cases into generatedTestsData") - } - - is Result.Failure -> { - project.createNotification(feedbackResponse.error, NotificationType.ERROR) - return null + val testSuite = runFeedbackCycle(indicator, llmFeedbackCycle) + val testSuiteRepresentation = + testSuite?.let { + log.info("Add ${it.testCases} compilable test cases into generatedTestsData") + val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) + testSuitePresenter.toString(testSuite = it) } - } if (ToolUtils.isProcessStopped(errorMonitor, indicator)) throw ProcessCancelledException(TestSparkModule.Llm()) - // Error during the collecting - if (errorMonitor.hasErrorOccurred()) throw ProcessCancelledException(TestSparkModule.Llm()) - - log.info("Save generated test suite and test cases into the project workspace") - - val testSuitePresenter = JUnitTestSuitePresenter(project, generatedTestsData, language) - val generatedTestSuite: TestSuiteGeneratedByLLM? = feedbackResponse.getDataOrNull()?.generatedTestSuite - val testSuiteRepresentation = - if (generatedTestSuite != null) testSuitePresenter.toString(generatedTestSuite) else null - ToolUtils.transferToIJTestCases(report) + log.info("Save generated test suite and test cases into the project workspace") ToolUtils.saveData( project, report, @@ -251,6 +228,36 @@ class LLMProcessManager( language, ) - return UIContext(projectContext, generatedTestsData, requestManager, indicatorController, errorMonitor) + return UIContext(projectContext, generatedTestsData, chatSessionManager, indicatorController, errorMonitor) } + + private fun runFeedbackCycle( + indicator: CustomProgressIndicator, + llmFeedbackCycle: LLMWithFeedbackCycle, + ): TestSuiteGeneratedByLLM? = + runBlockingWithIndicatorLifecycle(indicator) { + var feedbackResponse: TestSuiteGeneratedByLLM? = null + llmFeedbackCycle.run().collect { result -> + when (result) { + is Result.Success -> { + feedbackResponse = result.data + } + + is Result.Failure -> { + when (result.error) { + is LlmError.EmptyLlmResponse, + is LlmError.TestSuiteParsingError, + is LlmError.CompilationError, + -> { + project.createNotification(result.error, NotificationType.WARNING) + } + + else -> project.createNotification(result.error, NotificationType.ERROR) + } + } + } + } + log.info("Feedback cycle finished execution with result: $feedbackResponse") + feedbackResponse + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt deleted file mode 100644 index 31abb2d3a..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/RequestManagerFactory.kt +++ /dev/null @@ -1,31 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation - -import com.intellij.openapi.project.Project -import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager -import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState -import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments -import org.jetbrains.research.testspark.tools.llm.generation.gemini.GeminiRequestManager -import org.jetbrains.research.testspark.tools.llm.generation.grazie.GrazieRequestManager -import org.jetbrains.research.testspark.tools.llm.generation.hf.HuggingFaceRequestManager -import org.jetbrains.research.testspark.tools.llm.generation.openai.OpenAIRequestManager - -interface RequestManagerFactory { - fun getRequestManager(project: Project): RequestManager -} - -class StandardRequestManagerFactory( - private val project: Project, -) : RequestManagerFactory { - private val llmSettingsState: LLMSettingsState - get() = project.getService(LLMSettingsService::class.java).state - - override fun getRequestManager(project: Project): RequestManager = - when (val platform = LlmSettingsArguments(project).currentLLMPlatformName()) { - llmSettingsState.openAIName -> OpenAIRequestManager(project) - llmSettingsState.grazieName -> GrazieRequestManager(project) - llmSettingsState.huggingFaceName -> HuggingFaceRequestManager(project) - llmSettingsState.geminiName -> GeminiRequestManager(project) - else -> throw IllegalStateException("Unknown selected platform: $platform") - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/TestSparkRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/TestSparkRequestManager.kt deleted file mode 100644 index 07b2577cf..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/TestSparkRequestManager.kt +++ /dev/null @@ -1,64 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation - -import com.intellij.openapi.project.Project -import com.intellij.util.io.HttpRequests -import com.intellij.util.io.HttpRequests.HttpStatusException -import org.jetbrains.research.testspark.core.data.TestSparkModule -import org.jetbrains.research.testspark.core.error.HttpError -import org.jetbrains.research.testspark.core.error.Result -import org.jetbrains.research.testspark.core.error.TestSparkError -import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments -import java.net.HttpURLConnection -import java.net.URLConnection - -abstract class TestSparkRequestManager( - project: Project, -) : RequestManager( - token = LlmSettingsArguments(project).getToken(), - llmModel = LlmSettingsArguments(project).getModel(), - ) { - protected abstract val url: String - - protected abstract fun assembleRequestBodyJson(): String - - /** Set headers, tokens, etc. */ - protected abstract fun tuneRequest(connection: URLConnection) - - protected abstract fun assembleResponse( - httpRequest: HttpRequests.Request, - testsAssembler: TestsAssembler, - indicator: CustomProgressIndicator, - errorMonitor: ErrorMonitor, - ) - - protected open fun mapHttpCodeToError(httpCode: Int): TestSparkError = HttpError(httpCode = httpCode, module = TestSparkModule.Llm()) - - override fun send( - prompt: String, - indicator: CustomProgressIndicator, - testsAssembler: TestsAssembler, - errorMonitor: ErrorMonitor, - ): Result = - try { - HttpRequests - .post(url, "application/json") - .tuner { tuneRequest(it) } - .connect { request -> - request.write(assembleRequestBodyJson()) - val connection = request.connection as HttpURLConnection - when (val responseCode = connection.responseCode) { - HttpURLConnection.HTTP_OK -> - Result.Success( - data = assembleResponse(request, testsAssembler, indicator, errorMonitor), - ) - else -> Result.Failure(mapHttpCodeToError(responseCode)) - } - } - } catch (exception: HttpStatusException) { - Result.Failure(HttpError(cause = exception)) - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestBody.kt deleted file mode 100644 index f728156f3..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestBody.kt +++ /dev/null @@ -1,28 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation.gemini - -data class GeminiRequest( - val contents: List, -) - -data class GeminiRequestBody( - val parts: List, -) - -data class GeminiChatMessage( - val text: String, -) - -data class GeminiReply( - val content: GeminiReplyContent, - val finishReason: String, - val avgLogprobs: Double, -) - -data class GeminiReplyContent( - val parts: List, - val role: String?, -) - -data class GeminiReplyPart( - val text: String, -) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManager.kt deleted file mode 100644 index 7004f9d0d..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManager.kt +++ /dev/null @@ -1,72 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation.gemini - -import com.google.gson.GsonBuilder -import com.google.gson.JsonParser -import com.intellij.openapi.project.Project -import com.intellij.util.io.HttpRequests -import org.jetbrains.research.testspark.core.data.LlmModuleType -import org.jetbrains.research.testspark.core.data.TestSparkModule -import org.jetbrains.research.testspark.core.error.HttpError -import org.jetbrains.research.testspark.core.error.LlmError -import org.jetbrains.research.testspark.core.error.TestSparkError -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.tools.ToolUtils -import org.jetbrains.research.testspark.tools.llm.generation.TestSparkRequestManager -import java.net.HttpURLConnection -import java.net.URLConnection - -class GeminiRequestManager( - project: Project, -) : TestSparkRequestManager(project) { - private val gson = GsonBuilder().create() - - override val url: String - get() { - val baseUrl = "https://generativelanguage.googleapis.com/v1beta/models/" - return "$baseUrl$llmModel:generateContent?key=$token" - } - - override fun tuneRequest(connection: URLConnection) = Unit - - override fun assembleRequestBodyJson(): String { - val messages = chatHistory.map { GeminiChatMessage(it.content) } - val geminiRequest = GeminiRequest(listOf(GeminiRequestBody(messages))) - return gson.toJson(geminiRequest) - } - - override fun mapHttpCodeToError(httpCode: Int): TestSparkError = - when (httpCode) { - HttpURLConnection.HTTP_BAD_REQUEST -> LlmError.PromptTooLong - else -> HttpError(httpCode = httpCode, module = TestSparkModule.Llm(LlmModuleType.Gemini)) - } - - override fun assembleResponse( - httpRequest: HttpRequests.Request, - testsAssembler: TestsAssembler, - indicator: CustomProgressIndicator, - errorMonitor: ErrorMonitor, - ) { - while (true) { - if (ToolUtils.isProcessCanceled(errorMonitor, indicator)) return - - val text = httpRequest.reader.readText() - val result = - gson.fromJson( - JsonParser - .parseString(text) - .asJsonObject["candidates"] - .asJsonArray[0] - .asJsonObject, - GeminiReply::class.java, - ) - - testsAssembler.consume(result.content.parts[0].text) - - if (result.finishReason == "STOP") break - } - - log.debug { testsAssembler.getContent() } - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequest.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequest.kt index 35d226314..1d59110f5 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequest.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequest.kt @@ -1,12 +1,11 @@ package org.jetbrains.research.testspark.tools.llm.generation.grazie -import org.jetbrains.research.testspark.core.test.TestsAssembler +import kotlinx.coroutines.flow.Flow interface GrazieRequest { - fun request( + suspend fun request( token: String, messages: List>, profile: String, - testsAssembler: TestsAssembler, - ): String + ): Flow } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt index 92e7a737d..82d587b87 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/grazie/GrazieRequestManager.kt @@ -1,6 +1,8 @@ package org.jetbrains.research.testspark.tools.llm.generation.grazie -import com.intellij.openapi.project.Project +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.map import org.jetbrains.research.testspark.core.data.ChatMessage import org.jetbrains.research.testspark.core.data.LlmModuleType import org.jetbrains.research.testspark.core.data.TestSparkModule @@ -9,69 +11,50 @@ import org.jetbrains.research.testspark.core.error.LlmError import org.jetbrains.research.testspark.core.error.Result import org.jetbrains.research.testspark.core.error.TestSparkError import org.jetbrains.research.testspark.core.generation.llm.network.RequestManager -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.tools.llm.LlmSettingsArguments +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmParams -class GrazieRequestManager( - project: Project, -) : RequestManager( - token = LlmSettingsArguments(project).getToken(), - llmModel = LlmSettingsArguments(project).getModel(), - ) { - override fun send( - prompt: String, - indicator: CustomProgressIndicator, - testsAssembler: TestsAssembler, - errorMonitor: ErrorMonitor, - ): Result = - try { - val className = "org.jetbrains.research.grazie.Request" - val request: GrazieRequest = Class.forName(className).getDeclaredConstructor().newInstance() as GrazieRequest +class GrazieRequestManager : RequestManager { + override suspend fun sendRequest( + params: LlmParams, + chatHistory: List, + ): Flow> { + val className = "org.jetbrains.research.grazie.Request" + val request: GrazieRequest = + Class.forName(className).getDeclaredConstructor().newInstance() as GrazieRequest - val requestError = request.request(token, getMessages(), llmModel, testsAssembler) - - if (requestError.isNotEmpty()) { - with(requestError) { - when { - contains("invalid: 401") -> - Result.Failure( - error = HttpError(httpCode = 401), - ) - - contains("invalid: 413 Payload Too Large") -> - Result.Failure( - error = LlmError.PromptTooLong, - ) - contains("Provided prompt is too big for this model") && contains("invalid: 412 Precondition Failed") -> - Result.Failure( - error = LlmError.PromptTooLong, - ) - else -> - Result.Failure( - error = - HttpError( - message = this, - module = TestSparkModule.Llm(LlmModuleType.Grazie), - ), - ) + val messages = + chatHistory.map { + val role = + when (it.role) { + ChatMessage.ChatRole.User -> "user" + ChatMessage.ChatRole.Assistant -> "assistant" } - } - } else { - Result.Success(data = Unit) + (role to it.content) } - } catch (_: ClassNotFoundException) { - Result.Failure(error = LlmError.GrazieNotAvailable) - } - private fun getMessages(): List> = - chatHistory.map { - val role = - when (it.role) { - ChatMessage.ChatRole.User -> "user" - ChatMessage.ChatRole.Assistant -> "assistant" - } - (role to it.content) + return request + .request(params.token, messages, params.model) + .map { Result.Success(data = it) as Result } + .catch { emit(Result.Failure(error = it.toError())) } + } + + companion object { + private fun Throwable.toError(): TestSparkError { + val message = message.toString() + val promptTooLong = message.contains("Provided prompt is too big for this model") + val preconditionFailed = message.contains("invalid: 412 Precondition Failed") + return when { + this is ClassNotFoundException -> LlmError.GrazieNotAvailable + message.contains("invalid: 401") -> HttpError(httpCode = 401) + message.contains("invalid: 413 Payload Too Large") -> LlmError.PromptTooLong + promptTooLong && preconditionFailed -> LlmError.PromptTooLong + + else -> + HttpError( + message = message, + module = TestSparkModule.Llm(LlmModuleType.Grazie), + ) + } } + } } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt deleted file mode 100644 index 4436d9482..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestBody.kt +++ /dev/null @@ -1,33 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation.hf - -import org.jetbrains.research.testspark.core.data.ChatMessage - -data class Parameters( - val topProbability: Double, - val temperature: Double, -) - -data class HuggingFaceRequestBody( - val messages: List, - val parameters: Parameters, -) - -/** - * Sets LLM settings required to send inference requests to HF - * For more info, see https://huggingface.co/docs/api-inference/en/detailed_parameters - */ -fun HuggingFaceRequestBody.toMap(): Map = - mapOf( - "inputs" to this.messages.joinToString(separator = "\n") { it.content }, - // TODO: These parameters can be set by the user in the plugin's settings too. - "parameters" to - mapOf( - "top_p" to this.parameters.topProbability, - "temperature" to this.parameters.temperature, - "min_length" to 4096, - "max_length" to 8192, - "max_new_tokens" to 250, - "max_time" to 120.0, - "return_full_text" to false, - ), - ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt deleted file mode 100644 index 65a0c14c0..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/hf/HuggingFaceRequestManager.kt +++ /dev/null @@ -1,91 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation.hf - -import com.google.gson.GsonBuilder -import com.google.gson.JsonParser -import com.intellij.openapi.project.Project -import com.intellij.util.io.HttpRequests -import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle -import org.jetbrains.research.testspark.core.data.ChatUserMessage -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.tools.llm.generation.TestSparkRequestManager -import java.net.URLConnection - -/** - * A class to manage requests sent to large language models hosted on HuggingFace - */ -class HuggingFaceRequestManager( - project: Project, -) : TestSparkRequestManager(project) { - // TODO: The user should be able to change these numbers in the plugin's settings - private val topProbability = 0.9 - private val temperature = 0.9 - - override val url: String - get() { - val baseUrl = "https://api-inference.huggingface.co/models/meta-llama/" - return "$baseUrl$llmModel" - } - - override fun assembleRequestBodyJson(): String { - if (chatHistory.size == 1) { - chatHistory[0] = - ChatUserMessage( - createInstructionPrompt( - chatHistory[0].content, - ), - ) - } - val llmRequestBody = - HuggingFaceRequestBody(chatHistory, Parameters(topProbability, temperature)).toMap() - return GsonBuilder().disableHtmlEscaping().create().toJson(llmRequestBody) - } - - override fun tuneRequest(connection: URLConnection) { - connection.setRequestProperty("Authorization", "Bearer $token") - } - - override fun assembleResponse( - httpRequest: HttpRequests.Request, - testsAssembler: TestsAssembler, - indicator: CustomProgressIndicator, - errorMonitor: ErrorMonitor, - ) { - val text = httpRequest.reader.readLine() - val generatedTestCases = - extractLLMGeneratedCode( - JsonParser - .parseString(text) - .asJsonArray[0] - .asJsonObject["generated_text"] - .asString - .trim(), - ) - testsAssembler.consume(generatedTestCases) - } - - /** - * Creates the required prompt for Llama models. For more details see: - * https://huggingface.co/blog/llama2#how-to-prompt-llama-2 - */ - private fun createInstructionPrompt(userMessage: String): String { - // TODO: This is Llama-specific and should support other LLMs hosted on HF too. - return "[INST] <> ${LLMDefaultsBundle.get("huggingFaceInitialSystemPrompt")} <> $userMessage [/INST]" - } - - /** - * Extracts code blocks in LLMs' response. - * Also, it handles the cases where the LLM-generated code does not end with ``` - */ - private fun extractLLMGeneratedCode(text: String): String { - // TODO: This method should support other languages other than Java. - val modifiedText = text.replace("```java", "```").replace("````", "```") - val tripleTickBlockIndex = modifiedText.indexOf("```") - val codePart = modifiedText.substring(tripleTickBlockIndex + 3) - val lines = codePart.lines() - val filteredLines = lines.filter { line -> line != "```" } - val code = filteredLines.joinToString("\n") - return "```\n$code\n```" - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIChoice.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIChoice.kt deleted file mode 100644 index 3fa6a4d4f..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIChoice.kt +++ /dev/null @@ -1,15 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation.openai - -import com.google.gson.annotations.SerializedName - -data class OpenAIChoice( - val index: Int, - val delta: Delta, - @SerializedName("finish_reason") - val finishedReason: String, -) - -data class Delta( - val role: String?, - val content: String, -) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt deleted file mode 100644 index f807803c4..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestBody.kt +++ /dev/null @@ -1,39 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation.openai - -/** - * Adheres the naming of fields for OpenAI chat completion API and checks the correctness of a `role`. - *
- * Use this class as a carrier of messages that should be sent to OpenAI API. - */ -data class OpenAIChatMessage( - val role: String, - val content: String, -) { - private companion object { - /** - * The API strictly defines the set of roles. - * The `function` role is omitted because it is already deprecated. - * - * See: https://platform.openai.com/docs/api-reference/chat/create - */ - val supportedRoles = listOf("user", "assistant", "system", "tool") - } - - init { - if (!supportedRoles.contains(role)) { - throw IllegalArgumentException( - "'$role' is not supported ${OpenAIChatMessage::class}. Available roles are: ${( - supportedRoles.joinToString( - ", ", - ) { "'$it'" } - )}", - ) - } - } -} - -data class OpenAIRequestBody( - val model: String, - val messages: List, - val stream: Boolean = true, -) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt deleted file mode 100644 index 752ae5811..000000000 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/openai/OpenAIRequestManager.kt +++ /dev/null @@ -1,79 +0,0 @@ -package org.jetbrains.research.testspark.tools.llm.generation.openai - -import com.google.gson.Gson -import com.google.gson.GsonBuilder -import com.google.gson.JsonParser -import com.intellij.openapi.project.Project -import com.intellij.util.io.HttpRequests -import org.jetbrains.research.testspark.core.data.ChatMessage -import org.jetbrains.research.testspark.core.data.LlmModuleType -import org.jetbrains.research.testspark.core.data.TestSparkModule -import org.jetbrains.research.testspark.core.error.HttpError -import org.jetbrains.research.testspark.core.error.LlmError -import org.jetbrains.research.testspark.core.error.TestSparkError -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.TestsAssembler -import org.jetbrains.research.testspark.tools.ToolUtils -import org.jetbrains.research.testspark.tools.llm.generation.TestSparkRequestManager -import java.net.HttpURLConnection -import java.net.URLConnection - -class OpenAIRequestManager( - project: Project, -) : TestSparkRequestManager(project) { - override val url = "https://api.openai.com/v1/chat/completions" - - override fun tuneRequest(connection: URLConnection) { - connection.setRequestProperty("Authorization", "Bearer $token") - } - - override fun assembleRequestBodyJson(): String { - val messages = - chatHistory.map { - val role = - when (it.role) { - ChatMessage.ChatRole.User -> "user" - ChatMessage.ChatRole.Assistant -> "assistant" - } - OpenAIChatMessage(role, it.content) - } - val llmRequestBody = OpenAIRequestBody(llmModel, messages) - return GsonBuilder().create().toJson(llmRequestBody) - } - - override fun mapHttpCodeToError(httpCode: Int): TestSparkError = - when (httpCode) { - HttpURLConnection.HTTP_BAD_REQUEST -> LlmError.PromptTooLong - else -> HttpError(httpCode = httpCode, module = TestSparkModule.Llm(LlmModuleType.OpenAi)) - } - - override fun assembleResponse( - httpRequest: HttpRequests.Request, - testsAssembler: TestsAssembler, - indicator: CustomProgressIndicator, - errorMonitor: ErrorMonitor, - ) { - while (true) { - if (ToolUtils.isProcessCanceled(errorMonitor, indicator)) return - - var text = httpRequest.reader.readLine() - if (text.isEmpty()) continue - text = text.removePrefix("data: ") - - val choices = - Gson().fromJson( - JsonParser - .parseString(text) - .asJsonObject["choices"] - .asJsonArray[0] - .asJsonObject, - OpenAIChoice::class.java, - ) - if (choices.finishedReason == "stop") break - - testsAssembler.consume(choices.delta.content) - } - log.debug { testsAssembler.getContent() } - } -} diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt index c95f3864b..040d8e242 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt @@ -34,7 +34,9 @@ class JUnitTestSuitePresenter( return testSuite.run { // Add each test - testCases.forEach { testCase -> testBody += "$testCase\n" } + testCases.forEach { testCase -> + if (testCase.isCompilable) testBody += "$testCase\n" + } TestGenerator.create(language).generateCode( project, diff --git a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt index 82f76a4c9..22eb831d0 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/runner/SettingsArgumentsLlmEvoSuiteTest.kt @@ -1,16 +1,9 @@ package org.jetbrains.research.testspark.runner -import com.intellij.openapi.application.ApplicationManager -import com.intellij.testFramework.fixtures.CodeInsightTestFixture -import com.intellij.testFramework.fixtures.IdeaProjectTestFixture -import com.intellij.testFramework.fixtures.IdeaTestFixtureFactory -import com.intellij.testFramework.fixtures.JavaTestFixtureFactory -import com.intellij.testFramework.fixtures.TestFixtureBuilder import org.assertj.core.api.Assertions.assertThat import org.jetbrains.research.testspark.services.EvoSuiteSettingsService import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState import org.jetbrains.research.testspark.tools.evosuite.EvoSuiteSettingsArguments -import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance @@ -19,27 +12,14 @@ import org.junit.jupiter.api.TestInstance class SettingsArgumentsLlmEvoSuiteTest { private lateinit var settingsState: EvoSuiteSettingsState - private lateinit var fixture: CodeInsightTestFixture - @BeforeEach fun setUp() { - val projectBuilder: TestFixtureBuilder = - IdeaTestFixtureFactory.getFixtureFactory().createFixtureBuilder("project") - - fixture = JavaTestFixtureFactory.getFixtureFactory().createCodeInsightFixture(projectBuilder.fixture) - fixture.setUp() - - val settingsService = ApplicationManager.getApplication().getService(EvoSuiteSettingsService::class.java) + val settingsService = EvoSuiteSettingsService() settingsService.loadState(EvoSuiteSettingsState()) settingsState = settingsService.state } - @AfterEach - fun tearDown() { - fixture.tearDown() - } - @Test fun testCommandForClass() { val settings = diff --git a/src/test/kotlin/org/jetbrains/research/testspark/settings/EvoSuiteSettingsConfigurableTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/settings/EvoSuiteSettingsConfigurableTest.kt index 4baf54696..e54e95fa5 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/settings/EvoSuiteSettingsConfigurableTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/settings/EvoSuiteSettingsConfigurableTest.kt @@ -1,54 +1,42 @@ package org.jetbrains.research.testspark.settings -import com.intellij.testFramework.fixtures.CodeInsightTestFixture -import com.intellij.testFramework.fixtures.IdeaProjectTestFixture -import com.intellij.testFramework.fixtures.IdeaTestFixtureFactory -import com.intellij.testFramework.fixtures.JavaTestFixtureFactory -import com.intellij.testFramework.fixtures.TestFixtureBuilder +import com.intellij.openapi.project.Project import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.jetbrains.research.testspark.services.EvoSuiteSettingsService import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsComponent import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsConfigurable import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsState -import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito +import org.mockito.Mockito.`when` import java.util.stream.Stream @TestInstance(TestInstance.Lifecycle.PER_CLASS) class EvoSuiteSettingsConfigurableTest { private lateinit var settingsComponent: EvoSuiteSettingsComponent private lateinit var settingsState: EvoSuiteSettingsState - private lateinit var fixture: CodeInsightTestFixture private lateinit var settingsConfigurable: EvoSuiteSettingsConfigurable + private val project = Mockito.mock(Project::class.java) + @BeforeEach fun setUp() { - val projectBuilder: TestFixtureBuilder = - IdeaTestFixtureFactory.getFixtureFactory().createFixtureBuilder("project") + val evoSuiteSettingsService = EvoSuiteSettingsService() - fixture = - JavaTestFixtureFactory - .getFixtureFactory() - .createCodeInsightFixture(projectBuilder.fixture) - fixture.setUp() + `when`(project.getService(EvoSuiteSettingsService::class.java)).thenReturn(evoSuiteSettingsService) - settingsConfigurable = EvoSuiteSettingsConfigurable(fixture.project) + settingsConfigurable = EvoSuiteSettingsConfigurable(project) settingsConfigurable.createComponent() settingsConfigurable.reset() settingsComponent = settingsConfigurable.settingsComponent!! - settingsState = fixture.project.getService(EvoSuiteSettingsService::class.java).state - } - - @AfterEach - fun tearDown() { - fixture.tearDown() + settingsState = evoSuiteSettingsService.state } @Test diff --git a/src/test/kotlin/org/jetbrains/research/testspark/settings/PluginSettingsConfigurableTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/settings/PluginSettingsConfigurableTest.kt index bc2c245df..b3c79d5b7 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/settings/PluginSettingsConfigurableTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/settings/PluginSettingsConfigurableTest.kt @@ -1,13 +1,8 @@ package org.jetbrains.research.testspark.settings -import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.components.service -import com.intellij.testFramework.fixtures.CodeInsightTestFixture -import com.intellij.testFramework.fixtures.IdeaProjectTestFixture -import com.intellij.testFramework.fixtures.IdeaTestFixtureFactory -import com.intellij.testFramework.fixtures.JavaTestFixtureFactory -import com.intellij.testFramework.fixtures.TestFixtureBuilder +import com.intellij.openapi.project.Project import org.assertj.core.api.Assertions.assertThat +import org.jetbrains.research.testspark.services.EvoSuiteSettingsService import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.services.PluginSettingsService import org.jetbrains.research.testspark.settings.evosuite.EvoSuiteSettingsComponent @@ -24,6 +19,8 @@ import org.junit.jupiter.api.TestInstance import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource +import org.mockito.Mockito +import org.mockito.Mockito.`when` import java.util.stream.Stream @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -33,40 +30,37 @@ class PluginSettingsConfigurableTest { private lateinit var settingsComponent: PluginSettingsComponent private lateinit var settingsEvoComponent: EvoSuiteSettingsComponent private lateinit var settingsState: PluginSettingsState - private lateinit var fixture: CodeInsightTestFixture private lateinit var settingsApplicationState: LLMSettingsState + private val project = Mockito.mock(Project::class.java) + @BeforeEach fun setUp() { - val projectBuilder: TestFixtureBuilder = - IdeaTestFixtureFactory.getFixtureFactory().createFixtureBuilder("project") - - fixture = - JavaTestFixtureFactory - .getFixtureFactory() - .createCodeInsightFixture(projectBuilder.fixture) - fixture.setUp() + val evoSuiteSettingsService = EvoSuiteSettingsService() + val pluginSettingsService = PluginSettingsService() + val llmSettingsService = LLMSettingsService() - settingsConfigurable = PluginSettingsConfigurable(fixture.project) + `when`(project.getService(EvoSuiteSettingsService::class.java)).thenReturn(evoSuiteSettingsService) + `when`(project.getService(PluginSettingsService::class.java)).thenReturn(pluginSettingsService) + settingsConfigurable = PluginSettingsConfigurable(project) settingsConfigurable.createComponent() settingsConfigurable.reset() - settingsEvoConfigurable = EvoSuiteSettingsConfigurable(fixture.project) + settingsEvoConfigurable = EvoSuiteSettingsConfigurable(project) settingsEvoConfigurable.createComponent() settingsEvoConfigurable.reset() settingsEvoComponent = settingsEvoConfigurable.settingsComponent!! settingsComponent = settingsConfigurable.settingsComponent!! - settingsState = fixture.project.service().state + settingsState = pluginSettingsService.state - settingsApplicationState = ApplicationManager.getApplication().getService(LLMSettingsService::class.java).state + settingsApplicationState = llmSettingsService.state } @AfterEach fun tearDown() { - fixture.tearDown() settingsConfigurable.disposeUIResources() } diff --git a/src/test/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManagerTest.kt b/src/test/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManagerTest.kt index c3314ac56..3d233a9db 100644 --- a/src/test/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManagerTest.kt +++ b/src/test/kotlin/org/jetbrains/research/testspark/tools/llm/generation/gemini/GeminiRequestManagerTest.kt @@ -1,94 +1,60 @@ package org.jetbrains.research.testspark.tools.llm.generation.gemini -import com.intellij.openapi.project.Project -import org.jetbrains.research.testspark.bundles.llm.LLMDefaultsBundle -import org.jetbrains.research.testspark.core.data.JUnitVersion -import org.jetbrains.research.testspark.core.data.TestGenerationData -import org.jetbrains.research.testspark.core.monitor.ErrorMonitor -import org.jetbrains.research.testspark.core.progress.CustomProgressIndicator -import org.jetbrains.research.testspark.core.test.JUnitTestSuiteParser -import org.jetbrains.research.testspark.core.test.SupportedLanguage -import org.jetbrains.research.testspark.core.test.TestsAssembler +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import org.jetbrains.research.testspark.core.data.ChatMessage +import org.jetbrains.research.testspark.core.generation.llm.network.HttpRequestManager +import org.jetbrains.research.testspark.core.generation.llm.network.LlmProvider +import org.jetbrains.research.testspark.core.generation.llm.network.model.LlmParams import org.jetbrains.research.testspark.helpers.LLMHelper -import org.jetbrains.research.testspark.services.LLMSettingsService -import org.jetbrains.research.testspark.settings.llm.LLMSettingsState -import org.jetbrains.research.testspark.tools.llm.generation.JUnitTestsAssembler -import org.junit.jupiter.api.Assertions.assertNotNull import org.junit.jupiter.api.Assertions.assertTrue -import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertAll import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable -import org.mockito.Mockito.mock -import org.mockito.Mockito.`when` @EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = ".+") class GeminiRequestManagerTest { - private lateinit var project: Project - private lateinit var testsAssembler: TestsAssembler - - private val apiKey: String = System.getenv("GOOGLE_API_KEY")!! - - private val indicator = mock(CustomProgressIndicator::class.java) - private val errorMonitor = mock(ErrorMonitor::class.java) - - @BeforeEach - fun setUp() { - project = mock(Project::class.java) - val settingsService = mock(LLMSettingsService::class.java) - val settingsState = mock(LLMSettingsState::class.java) - `when`(settingsState.currentLLMPlatformName).thenReturn(LLMDefaultsBundle.get("geminiName")) - `when`(settingsState.geminiName).thenReturn(LLMDefaultsBundle.get("geminiName")) - `when`(settingsState.geminiToken).thenReturn(apiKey) - `when`(settingsState.geminiModel).thenReturn("gemini-1.5-flash") - `when`(settingsService.state).thenReturn(settingsState) - `when`(project.getService(LLMSettingsService::class.java)).thenReturn(settingsService) - - testsAssembler = - JUnitTestsAssembler( - indicator, - mock(TestGenerationData::class.java), - mock(JUnitTestSuiteParser::class.java), - JUnitVersion.JUnit5, - ) - } - - @Test - fun `test request manager implementation for Google Gemini`() { - val manager = GeminiRequestManager(project) - val prompt = - """ - You are a Java tester. Provide a test case that covers the following code snippet: - - ```java - package com.example; - public class Foo { - public int sign(int x) { - if (x > 0) return 1 - if (x < 0) return -1 - return 0 - } - } - ``` - """.trimIndent() - manager.request( - SupportedLanguage.Java, - prompt, - indicator, - "com.example", - testsAssembler, - false, - errorMonitor, + private val requestManager = HttpRequestManager(llmProvider = LlmProvider.Gemini) + private val llmParams = + LlmParams( + model = "gemini-1.5-flash", + token = System.getenv("GOOGLE_API_KEY")!!, ) - val result = manager.send(prompt, indicator, testsAssembler, errorMonitor) - val llmResult = testsAssembler.getContent() - assertNotNull(result) - assertNotNull(llmResult) - } + @Test + fun `test request manager implementation for Google Gemini`() = + runTest { + val prompt = + """ + You are a Java tester. Provide a test case that covers the following code snippet: + + ```java + package com.example; + public class Foo { + public int sign(int x) { + if (x > 0) return 1 + if (x < 0) return -1 + return 0 + } + } + ``` + """.trimIndent() + val chunks = + requestManager + .sendRequest( + llmParams, + listOf(ChatMessage.createUserMessage(prompt)), + ).toList() + assertAll( + chunks.map { chunk -> + { assertTrue(chunk.isSuccess()) } + }, + ) + } @Test fun `test the retrieved Gemini models`() { - val models = LLMHelper.getGeminiModels(apiKey) + val models = LLMHelper.getGeminiModels(llmParams.token) assertTrue(models.isNotEmpty()) } }