From 9108415261afa2a56a4f5a8ff27ef72ff0c7f97b Mon Sep 17 00:00:00 2001 From: Artifizer Date: Wed, 2 Jul 2025 14:00:21 +0300 Subject: [PATCH 1/4] feat: show the progress on model load, like in the *lms* python tool --- cli/lms_go.go | 87 +++++++++++++++++++++++++++++++-- pkg/lmstudio/lmstudio_client.go | 54 ++++++++++++++++++++ 2 files changed, 138 insertions(+), 3 deletions(-) diff --git a/cli/lms_go.go b/cli/lms_go.go index 68f7817..44f2357 100644 --- a/cli/lms_go.go +++ b/cli/lms_go.go @@ -177,6 +177,89 @@ func truncateString(s string, maxLen int) string { return s[:maxLen-3] + "..." } +// loadModelWithProgress loads a model and displays a progress bar with model information +func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier string, logger lmstudio.Logger) error { + var modelInfo *lmstudio.Model + var modelDisplayed bool + var lastProgress float64 = -1 + + // Use the client's LoadModelWithProgress method + err := client.LoadModelWithProgress(modelIdentifier, func(progress float64, info *lmstudio.Model) { + // Display model info on first callback + if !modelDisplayed { + modelInfo = info + if modelInfo != nil { + fmt.Printf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), modelInfo.Format) + if modelInfo.Size > 0 { + // Extract format from model info for display + format := modelInfo.Format + if format == "" && modelInfo.Path != "" { + if strings.Contains(modelInfo.Path, "MLX") { + format = "MLX" + } else if strings.Contains(modelInfo.Path, "GGUF") { + format = "GGUF" + } + } + + // Display size and format like in the screenshot + sizeStr := formatSize(modelInfo.Size) + if format != "" { + fmt.Printf("Model: %s (%s)\n", sizeStr, format) + } else { + fmt.Printf("Model: %s\n", sizeStr) + } + } + } else { + fmt.Printf("Loading model \"%s\" ...\n", modelIdentifier) + } + modelDisplayed = true + } + + // Only update progress if it increased significantly to avoid flickering + if progress > lastProgress+0.001 || progress >= 1.0 { + displayProgressBar(progress) + lastProgress = progress + } + + // If model was already loaded, show completion immediately + if progress >= 1.0 { + fmt.Printf("\n✓ Model loaded successfully\n") + } + }) + + if err != nil { + fmt.Printf("\nFailed to load model: %v\n", err) + return err + } + + return nil +} + +// displayProgressBar shows a progress bar similar to the screenshot +func displayProgressBar(progress float64) { + const barWidth = 50 + percentage := progress * 100 + + // Calculate number of filled characters + filled := int(progress * float64(barWidth)) + + // Build the progress bar using block characters like in the screenshot + bar := make([]rune, barWidth) + for i := 0; i < barWidth; i++ { + if i < filled { + bar[i] = '█' // Full block + } else { + bar[i] = '░' // Light shade + } + } + + // Print progress bar with percentage (carriage return to overwrite) + fmt.Printf("\r: [%s] %.2f%%", string(bar), percentage) + + // Force output to be displayed immediately + os.Stdout.Sync() +} + func main() { // Setup code coverage if running instrumented build if coverageFile != "" { @@ -354,12 +437,10 @@ func main() { // Load a model if *loadModel != "" { operation = true - fmt.Printf("Loading model: %s\n", *loadModel) - if err := client.LoadModel(*loadModel); err != nil { + if err := loadModelWithProgress(client, *loadModel, logger); err != nil { logger.Error("Failed to load model: %v", err) os.Exit(1) } - fmt.Printf("Model %s loaded successfully\n", *loadModel) } // Unload a model diff --git a/pkg/lmstudio/lmstudio_client.go b/pkg/lmstudio/lmstudio_client.go index 9a6c22a..9f2c78f 100644 --- a/pkg/lmstudio/lmstudio_client.go +++ b/pkg/lmstudio/lmstudio_client.go @@ -288,6 +288,60 @@ func (c *LMStudioClient) LoadModel(modelIdentifier string) error { return c.waitForModelLoading(channel, modelIdentifier, loadTimeout) } +// LoadModelWithProgress loads a specified model in LM Studio with progress reporting +func (c *LMStudioClient) LoadModelWithProgress(modelIdentifier string, progressCallback func(progress float64, modelInfo *Model)) error { + // Get model information from downloaded models + var modelInfo *Model + downloadedModels, err := c.ListDownloadedModels() + if err != nil { + c.logger.Warn("Failed to get model information: %v", err) + } else { + for _, model := range downloadedModels { + if model.ModelKey == modelIdentifier || model.DisplayName == modelIdentifier || model.ModelName == modelIdentifier { + modelInfo = &model + break + } + } + } + + // Check if the model exists in downloaded models + if err := c.checkModelExists(modelIdentifier); err != nil { + return err + } + + // Check if the model is already loaded + if c.isModelAlreadyLoaded(modelIdentifier) { + // Call progress callback with 100% completion for already loaded model + if progressCallback != nil { + progressCallback(1.0, modelInfo) + } + return nil + } + + // Create a model loading channel with progress callback + c.logger.Debug("Creating model loading channel for: %s", modelIdentifier) + channel, err := c.NewModelLoadingChannel(LLMNamespace, func(progress float64) { + c.logger.Debug("Loading model %s: %.1f%% complete", modelIdentifier, progress*100) + if progressCallback != nil { + progressCallback(progress, modelInfo) + } + }) + if err != nil { + return fmt.Errorf("failed to create model loading channel: %w", err) + } + defer channel.Close() + + // Create the channel and start loading the model + err = channel.CreateChannel(modelIdentifier) + if err != nil { + return fmt.Errorf("failed to start model loading: %w", err) + } + + // Use a longer timeout for model loading - some large models can take several minutes + loadTimeout := 120 * time.Second + return c.waitForModelLoading(channel, modelIdentifier, loadTimeout) +} + // UnloadModel unloads a specified model in LM Studio func (c *LMStudioClient) UnloadModel(modelIdentifier string) error { // Get or create a connection to the LLM namespace From ec00008a0f6b7904488c4695efce433d36cd53f1 Mon Sep 17 00:00:00 2001 From: Artifizer Date: Wed, 2 Jul 2025 14:02:33 +0300 Subject: [PATCH 2/4] test: added the unit tests for the model loading with progress --- pkg/lmstudio/lmstudio_client_test.go | 172 +++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/pkg/lmstudio/lmstudio_client_test.go b/pkg/lmstudio/lmstudio_client_test.go index 5a2d60a..8d9843f 100644 --- a/pkg/lmstudio/lmstudio_client_test.go +++ b/pkg/lmstudio/lmstudio_client_test.go @@ -344,3 +344,175 @@ func TestSendPrompt(t *testing.T) { } } } + +// TestLoadModelWithProgress tests the LoadModelWithProgress method +func TestLoadModelWithProgress(t *testing.T) { + fmt.Println("[TEST] TestLoadModelWithProgress started") + defer fmt.Println("[TEST] TestLoadModelWithProgress finished or failed") + defer func() { + if r := recover(); r != nil { + t.Fatalf("TestLoadModelWithProgress panicked: %v", r) + } + }() + + t.Parallel() + // Add a timeout to prevent hanging forever + done := make(chan struct{}) + go func() { + // Create a mock server that handles model loading + server := NewMockLMStudioService(t, newMockLogger()) + defer server.Close() + + // Extract the host and port from the server URL + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + // Create a client that connects to our mock server + logger := newMockLogger() + client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) + defer client.Close() + + // Track progress callbacks + var progressCallbacks []float64 + var modelInfoCallbacks []*Model + var callbackMutex sync.Mutex + + progressCallback := func(progress float64, modelInfo *Model) { + callbackMutex.Lock() + defer callbackMutex.Unlock() + progressCallbacks = append(progressCallbacks, progress) + modelInfoCallbacks = append(modelInfoCallbacks, modelInfo) + logger.Debug("Progress callback: %.2f%%, model: %v", progress*100, modelInfo != nil) + } + + // Call the method we're testing + err = client.LoadModelWithProgress("mock-model-0.5B", progressCallback) + if err != nil { + t.Fatalf("LoadModelWithProgress failed: %v", err) + } + + // Verify that progress callbacks were called + callbackMutex.Lock() + if len(progressCallbacks) == 0 { + t.Errorf("Expected progress callbacks, got none") + } + + // Verify progress values are reasonable (between 0 and 1) + for i, progress := range progressCallbacks { + if progress < 0 || progress > 1 { + t.Errorf("Progress callback %d has invalid value: %f (should be between 0 and 1)", i, progress) + } + } + + // Verify we got at least one progress update + if len(progressCallbacks) < 1 { + t.Errorf("Expected at least 1 progress callback, got %d", len(progressCallbacks)) + } + + // Verify model info was provided when available + hasModelInfo := false + for _, modelInfo := range modelInfoCallbacks { + if modelInfo != nil { + hasModelInfo = true + if modelInfo.ModelKey == "" { + t.Errorf("Expected model info to have ModelKey") + } + break + } + } + if !hasModelInfo { + t.Errorf("Expected at least one callback with model info") + } + callbackMutex.Unlock() + + close(done) + }() + + // Wait for test completion or timeout + select { + case <-done: + // Test completed successfully + case <-time.After(10 * time.Second): + t.Fatal("TestLoadModelWithProgress: test timed out (possible deadlock or missing mock response)") + } +} + +// TestLoadModelWithProgressAlreadyLoaded tests LoadModelWithProgress when model is already loaded +func TestLoadModelWithProgressAlreadyLoaded(t *testing.T) { + fmt.Println("[TEST] TestLoadModelWithProgressAlreadyLoaded started") + defer fmt.Println("[TEST] TestLoadModelWithProgressAlreadyLoaded finished or failed") + + // Create a mock server + server := NewMockLMStudioService(t, newMockLogger()) + defer server.Close() + + // Extract the host and port from the server URL + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + // Create a client that connects to our mock server + logger := newMockLogger() + client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) + defer client.Close() + + // Track progress callbacks + var progressCallbacks []float64 + var modelInfoCallbacks []*Model + var callbackMutex sync.Mutex + + progressCallback := func(progress float64, modelInfo *Model) { + callbackMutex.Lock() + defer callbackMutex.Unlock() + progressCallbacks = append(progressCallbacks, progress) + modelInfoCallbacks = append(modelInfoCallbacks, modelInfo) + logger.Debug("Progress callback for already loaded: %.2f%%, model: %v", progress*100, modelInfo != nil) + } + + // Use a model that appears in the loaded models list (mock-model-7B) + err = client.LoadModelWithProgress("mock-model-7B", progressCallback) + if err != nil { + t.Fatalf("LoadModelWithProgress failed for already loaded model: %v", err) + } + + // Verify that we got exactly one callback with 100% progress for already loaded model + callbackMutex.Lock() + if len(progressCallbacks) != 1 { + t.Errorf("Expected exactly 1 progress callback for already loaded model, got %d", len(progressCallbacks)) + } + + if len(progressCallbacks) > 0 && progressCallbacks[0] != 1.0 { + t.Errorf("Expected progress to be 1.0 for already loaded model, got %f", progressCallbacks[0]) + } + callbackMutex.Unlock() +} + +// TestLoadModelWithProgressNilCallback tests LoadModelWithProgress with nil callback +func TestLoadModelWithProgressNilCallback(t *testing.T) { + fmt.Println("[TEST] TestLoadModelWithProgressNilCallback started") + defer fmt.Println("[TEST] TestLoadModelWithProgressNilCallback finished or failed") + + // Create a mock server + server := NewMockLMStudioService(t, newMockLogger()) + defer server.Close() + + // Extract the host and port from the server URL + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + // Create a client that connects to our mock server + logger := newMockLogger() + client := NewLMStudioClient(strings.TrimPrefix(serverURL.Host, "http://"), logger) + defer client.Close() + + // Call the method with nil callback (should not crash) + err = client.LoadModelWithProgress("mock-model-0.5B", nil) + if err != nil { + t.Fatalf("LoadModelWithProgress with nil callback failed: %v", err) + } +} From cd3751cc319b8c6e20fa83fb5b48142d193671fc Mon Sep 17 00:00:00 2001 From: Artifizer Date: Wed, 2 Jul 2025 14:14:19 +0300 Subject: [PATCH 3/4] test: fix the unit tests after recent CLI tool changes --- cli/lms_go_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cli/lms_go_test.go b/cli/lms_go_test.go index 720a083..fcbb43d 100644 --- a/cli/lms_go_test.go +++ b/cli/lms_go_test.go @@ -348,12 +348,12 @@ func testModelLoadingUnloading(t *testing.T) { return } - if !strings.Contains(stdout, fmt.Sprintf("Loading model: %s", testModel)) { - t.Errorf("Expected output to contain 'Loading model: %s', got:\n%s", testModel, stdout) + if !strings.Contains(stdout, "Loading model") { + t.Errorf("Expected output to contain 'Loading model'\n%s", stdout) } - if !strings.Contains(stdout, "loaded successfully") { - t.Errorf("Expected output to contain 'loaded successfully', got:\n%s", stdout) + if !strings.Contains(stdout, testModel) { + t.Errorf("Expected output to contain '%s', got:\n%s", testModel, stdout) } fmt.Printf("Step 3: Testing prompt with loaded model '%s'...\n", testModel) From 4045e7270297ec62919412385428d13779cd4178 Mon Sep 17 00:00:00 2001 From: Artifizer Date: Wed, 2 Jul 2025 14:14:53 +0300 Subject: [PATCH 4/4] feat: added -q|-quiet mode support to hide all the tool outputs --- cli/lms_go.go | 87 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 30 deletions(-) diff --git a/cli/lms_go.go b/cli/lms_go.go index 44f2357..5e837ad 100644 --- a/cli/lms_go.go +++ b/cli/lms_go.go @@ -14,6 +14,23 @@ import ( // coverageFile is set at build time via -ldflags for instrumented builds var coverageFile string +// Global quiet flag +var quietMode bool + +// quietPrintf prints to stdout only if not in quiet mode +func quietPrintf(format string, args ...interface{}) { + if !quietMode { + fmt.Printf(format, args...) + } +} + +// quietPrintln prints to stdout only if not in quiet mode +func quietPrintln(args ...interface{}) { + if !quietMode { + fmt.Println(args...) + } +} + // formatSize formats file size in a human-readable format func formatSize(size int64) string { if size == 0 { @@ -47,27 +64,27 @@ func formatMaxContext(maxContext int) string { func printTableHeader(columns []string, widths []int) { // Print header for i, col := range columns { - fmt.Printf("%-*s", widths[i], col) + quietPrintf("%-*s", widths[i], col) if i < len(columns)-1 { - fmt.Printf(" | ") + quietPrintf(" | ") } } - fmt.Println() + quietPrintln() // Print separator totalWidth := 0 for i, width := range widths { for j := 0; j < width; j++ { - fmt.Print("-") + quietPrintf("-") } if i < len(widths)-1 { - fmt.Print("-+-") + quietPrintf("-+-") totalWidth += width + 3 } else { totalWidth += width } } - fmt.Println() + quietPrintln() } // printModels prints models in a nice table format or JSON @@ -78,13 +95,13 @@ func printModels(models []lmstudio.Model, title string, jsonOutput bool) { // Ad fmt.Fprintf(os.Stderr, "Error marshalling to JSON: %v\n", err) os.Exit(1) // Or handle error more gracefully } - fmt.Println(string(jsonData)) + fmt.Println(string(jsonData)) // Always print JSON output regardless of quiet mode return } - fmt.Printf("\n%s:\n", title) + quietPrintf("\n%s:\n", title) if len(models) == 0 { - fmt.Printf("No %s found\n", strings.ToLower(title)) + quietPrintf("No %s found\n", strings.ToLower(title)) return } @@ -158,7 +175,7 @@ func printModels(models []lmstudio.Model, title string, jsonOutput bool) { // Ad } // Print the row - fmt.Printf("%-*s | %-15s | %-10s | %-10s | %-10s | %-50s\n", + quietPrintf("%-*s | %-15s | %-10s | %-10s | %-10s | %-50s\n", longestModelName, truncateString(name, longestModelName), truncateString(modelType, 15), @@ -189,7 +206,7 @@ func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier stri if !modelDisplayed { modelInfo = info if modelInfo != nil { - fmt.Printf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), modelInfo.Format) + quietPrintf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), modelInfo.Format) if modelInfo.Size > 0 { // Extract format from model info for display format := modelInfo.Format @@ -204,13 +221,13 @@ func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier stri // Display size and format like in the screenshot sizeStr := formatSize(modelInfo.Size) if format != "" { - fmt.Printf("Model: %s (%s)\n", sizeStr, format) + quietPrintf("Model: %s (%s)\n", sizeStr, format) } else { - fmt.Printf("Model: %s\n", sizeStr) + quietPrintf("Model: %s\n", sizeStr) } } } else { - fmt.Printf("Loading model \"%s\" ...\n", modelIdentifier) + quietPrintf("Loading model \"%s\" ...\n", modelIdentifier) } modelDisplayed = true } @@ -223,12 +240,12 @@ func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier stri // If model was already loaded, show completion immediately if progress >= 1.0 { - fmt.Printf("\n✓ Model loaded successfully\n") + quietPrintf("\n✓ Model loaded successfully\n") } }) if err != nil { - fmt.Printf("\nFailed to load model: %v\n", err) + quietPrintf("\nFailed to load model: %v\n", err) return err } @@ -237,6 +254,10 @@ func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier stri // displayProgressBar shows a progress bar similar to the screenshot func displayProgressBar(progress float64) { + if quietMode { + return // Don't display progress bar in quiet mode + } + const barWidth = 50 percentage := progress * 100 @@ -263,7 +284,7 @@ func displayProgressBar(progress float64) { func main() { // Setup code coverage if running instrumented build if coverageFile != "" { - fmt.Printf("Running with code coverage. Data will be written to: %s\n", coverageFile) + quietPrintf("Running with code coverage. Data will be written to: %s\n", coverageFile) } // Define command-line flags @@ -284,11 +305,16 @@ func main() { waitForInterrupt := flag.Bool("wait", false, "Wait for Ctrl+C to exit after command execution") checkStatus := flag.Bool("status", false, "Check if the LM Studio service is running") showVersion := flag.Bool("version", false, "Show version information") - jsonOutput := flag.Bool("json", false, "Output list commands in JSON format") // Added this flag + jsonOutput := flag.Bool("json", false, "Output list commands in JSON format") + quiet := flag.Bool("q", false, "Quiet mode - suppress all stdout messages except JSON output and errors") + quietLong := flag.Bool("quiet", false, "Quiet mode - suppress all stdout messages except JSON output and errors") // Parse command line flags flag.Parse() + // Set quiet mode if either -q or -quiet is specified + quietMode = *quiet || *quietLong + // Show help if requested if flag.NFlag() == 0 { fmt.Println("LM Studio Models CLI") @@ -446,30 +472,30 @@ func main() { // Unload a model if *unloadModel != "" { operation = true - fmt.Printf("Unloading model: %s\n", *unloadModel) + quietPrintf("Unloading model: %s\n", *unloadModel) if err := client.UnloadModel(*unloadModel); err != nil { // Check if the error is about the model not being found if strings.Contains(err.Error(), "No model found that fits the query") { - fmt.Printf("Model %s is not currently loaded. No action needed.\n", *unloadModel) + quietPrintf("Model %s is not currently loaded. No action needed.\n", *unloadModel) } else { logger.Error("Failed to unload model: %v", err) os.Exit(1) } } else { - fmt.Printf("Model %s unloaded successfully\n", *unloadModel) + quietPrintf("Model %s unloaded successfully\n", *unloadModel) } } // Handle unload all models command if *unloadAll { operation = true - fmt.Printf("Unloading all loaded models...\n") + quietPrintf("Unloading all loaded models...\n") err := client.UnloadAllModels() if err != nil { logger.Error("Failed to unload all models: %v", err) os.Exit(1) } - fmt.Printf("Unloaded all models successfully\n") + quietPrintf("Unloaded all models successfully\n") } // Handle prompt (new format with separate model and prompt options) @@ -498,15 +524,16 @@ func main() { modelIdentifier = models[0].ModelKey } - fmt.Printf("No model specified, using first loaded model: %s\n", modelIdentifier) + quietPrintf("No model specified, using first loaded model: %s\n", modelIdentifier) } - fmt.Printf("\nSending prompt to model: %s, temperature: %.2f\n", modelIdentifier, *temperature) - fmt.Printf("Prompt: %s\n", *promptText) - fmt.Println("Response:") + quietPrintf("\nSending prompt to model: %s, temperature: %.2f\n", modelIdentifier, *temperature) + quietPrintf("Prompt: %s\n", *promptText) + quietPrintln("Response:") // Create a callback to print tokens as they arrive callback := func(token string) { + // Always print tokens regardless of quiet mode - this is the actual output fmt.Print(token) } @@ -514,7 +541,7 @@ func main() { logger.Error("Failed to send prompt: %v", err) os.Exit(1) } - fmt.Println("") + quietPrintln("") // Print newline only in non-quiet mode } // If no operation was specified, list all loaded models as the default behavior @@ -532,11 +559,11 @@ func main() { // Wait for Ctrl+C to exit if requested if *waitForInterrupt { - fmt.Println("\nPress Ctrl+C to exit") + quietPrintln("\nPress Ctrl+C to exit") c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) <-c - fmt.Println("\nShutting down...") + quietPrintln("\nShutting down...") } // Write coverage data if running instrumented build