Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 135 additions & 27 deletions cli/lms_go.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ import (
// coverageFile is set at build time via -ldflags for instrumented builds
var coverageFile string

// Global quiet flag
var quietMode bool

// quietPrintf prints to stdout only if not in quiet mode
func quietPrintf(format string, args ...interface{}) {
if !quietMode {
fmt.Printf(format, args...)
}
}

// quietPrintln prints to stdout only if not in quiet mode
func quietPrintln(args ...interface{}) {
if !quietMode {
fmt.Println(args...)
}
}

// formatSize formats file size in a human-readable format
func formatSize(size int64) string {
if size == 0 {
Expand Down Expand Up @@ -47,27 +64,27 @@ func formatMaxContext(maxContext int) string {
func printTableHeader(columns []string, widths []int) {
// Print header
for i, col := range columns {
fmt.Printf("%-*s", widths[i], col)
quietPrintf("%-*s", widths[i], col)
if i < len(columns)-1 {
fmt.Printf(" | ")
quietPrintf(" | ")
}
}
fmt.Println()
quietPrintln()

// Print separator
totalWidth := 0
for i, width := range widths {
for j := 0; j < width; j++ {
fmt.Print("-")
quietPrintf("-")
}
if i < len(widths)-1 {
fmt.Print("-+-")
quietPrintf("-+-")
totalWidth += width + 3
} else {
totalWidth += width
}
}
fmt.Println()
quietPrintln()
}

// printModels prints models in a nice table format or JSON
Expand All @@ -78,13 +95,13 @@ func printModels(models []lmstudio.Model, title string, jsonOutput bool) { // Ad
fmt.Fprintf(os.Stderr, "Error marshalling to JSON: %v\n", err)
os.Exit(1) // Or handle error more gracefully
}
fmt.Println(string(jsonData))
fmt.Println(string(jsonData)) // Always print JSON output regardless of quiet mode
return
}

fmt.Printf("\n%s:\n", title)
quietPrintf("\n%s:\n", title)
if len(models) == 0 {
fmt.Printf("No %s found\n", strings.ToLower(title))
quietPrintf("No %s found\n", strings.ToLower(title))
return
}

Expand Down Expand Up @@ -158,7 +175,7 @@ func printModels(models []lmstudio.Model, title string, jsonOutput bool) { // Ad
}

// Print the row
fmt.Printf("%-*s | %-15s | %-10s | %-10s | %-10s | %-50s\n",
quietPrintf("%-*s | %-15s | %-10s | %-10s | %-10s | %-50s\n",
longestModelName,
truncateString(name, longestModelName),
truncateString(modelType, 15),
Expand All @@ -177,10 +194,97 @@ func truncateString(s string, maxLen int) string {
return s[:maxLen-3] + "..."
}

// loadModelWithProgress loads a model and displays a progress bar with model information
func loadModelWithProgress(client *lmstudio.LMStudioClient, modelIdentifier string, logger lmstudio.Logger) error {
var modelInfo *lmstudio.Model
var modelDisplayed bool
var lastProgress float64 = -1

// Use the client's LoadModelWithProgress method
err := client.LoadModelWithProgress(modelIdentifier, func(progress float64, info *lmstudio.Model) {
// Display model info on first callback
if !modelDisplayed {
modelInfo = info
if modelInfo != nil {
quietPrintf("Loading model \"%s\" (size: %s, format: %s) ...\n", modelInfo.ModelKey, formatSize(modelInfo.Size), modelInfo.Format)
if modelInfo.Size > 0 {
// Extract format from model info for display
format := modelInfo.Format
if format == "" && modelInfo.Path != "" {
if strings.Contains(modelInfo.Path, "MLX") {
format = "MLX"
} else if strings.Contains(modelInfo.Path, "GGUF") {
format = "GGUF"
}
}

// Display size and format like in the screenshot
sizeStr := formatSize(modelInfo.Size)
if format != "" {
quietPrintf("Model: %s (%s)\n", sizeStr, format)
} else {
quietPrintf("Model: %s\n", sizeStr)
}
}
} else {
quietPrintf("Loading model \"%s\" ...\n", modelIdentifier)
}
modelDisplayed = true
}

// Only update progress if it increased significantly to avoid flickering
if progress > lastProgress+0.001 || progress >= 1.0 {
displayProgressBar(progress)
lastProgress = progress
}

// If model was already loaded, show completion immediately
if progress >= 1.0 {
quietPrintf("\n✓ Model loaded successfully\n")
}
})

if err != nil {
quietPrintf("\nFailed to load model: %v\n", err)
return err
}

return nil
}

// displayProgressBar shows a progress bar similar to the screenshot
func displayProgressBar(progress float64) {
if quietMode {
return // Don't display progress bar in quiet mode
}

const barWidth = 50
percentage := progress * 100

// Calculate number of filled characters
filled := int(progress * float64(barWidth))

// Build the progress bar using block characters like in the screenshot
bar := make([]rune, barWidth)
for i := 0; i < barWidth; i++ {
if i < filled {
bar[i] = '█' // Full block
} else {
bar[i] = '░' // Light shade
}
}

// Print progress bar with percentage (carriage return to overwrite)
fmt.Printf("\r: [%s] %.2f%%", string(bar), percentage)

// Force output to be displayed immediately
os.Stdout.Sync()
}

func main() {
// Setup code coverage if running instrumented build
if coverageFile != "" {
fmt.Printf("Running with code coverage. Data will be written to: %s\n", coverageFile)
quietPrintf("Running with code coverage. Data will be written to: %s\n", coverageFile)
}

// Define command-line flags
Expand All @@ -201,11 +305,16 @@ func main() {
waitForInterrupt := flag.Bool("wait", false, "Wait for Ctrl+C to exit after command execution")
checkStatus := flag.Bool("status", false, "Check if the LM Studio service is running")
showVersion := flag.Bool("version", false, "Show version information")
jsonOutput := flag.Bool("json", false, "Output list commands in JSON format") // Added this flag
jsonOutput := flag.Bool("json", false, "Output list commands in JSON format")
quiet := flag.Bool("q", false, "Quiet mode - suppress all stdout messages except JSON output and errors")
quietLong := flag.Bool("quiet", false, "Quiet mode - suppress all stdout messages except JSON output and errors")

// Parse command line flags
flag.Parse()

// Set quiet mode if either -q or -quiet is specified
quietMode = *quiet || *quietLong

// Show help if requested
if flag.NFlag() == 0 {
fmt.Println("LM Studio Models CLI")
Expand Down Expand Up @@ -354,41 +463,39 @@ func main() {
// Load a model
if *loadModel != "" {
operation = true
fmt.Printf("Loading model: %s\n", *loadModel)
if err := client.LoadModel(*loadModel); err != nil {
if err := loadModelWithProgress(client, *loadModel, logger); err != nil {
logger.Error("Failed to load model: %v", err)
os.Exit(1)
}
fmt.Printf("Model %s loaded successfully\n", *loadModel)
}

// Unload a model
if *unloadModel != "" {
operation = true
fmt.Printf("Unloading model: %s\n", *unloadModel)
quietPrintf("Unloading model: %s\n", *unloadModel)
if err := client.UnloadModel(*unloadModel); err != nil {
// Check if the error is about the model not being found
if strings.Contains(err.Error(), "No model found that fits the query") {
fmt.Printf("Model %s is not currently loaded. No action needed.\n", *unloadModel)
quietPrintf("Model %s is not currently loaded. No action needed.\n", *unloadModel)
} else {
logger.Error("Failed to unload model: %v", err)
os.Exit(1)
}
} else {
fmt.Printf("Model %s unloaded successfully\n", *unloadModel)
quietPrintf("Model %s unloaded successfully\n", *unloadModel)
}
}

// Handle unload all models command
if *unloadAll {
operation = true
fmt.Printf("Unloading all loaded models...\n")
quietPrintf("Unloading all loaded models...\n")
err := client.UnloadAllModels()
if err != nil {
logger.Error("Failed to unload all models: %v", err)
os.Exit(1)
}
fmt.Printf("Unloaded all models successfully\n")
quietPrintf("Unloaded all models successfully\n")
}

// Handle prompt (new format with separate model and prompt options)
Expand Down Expand Up @@ -417,23 +524,24 @@ func main() {
modelIdentifier = models[0].ModelKey
}

fmt.Printf("No model specified, using first loaded model: %s\n", modelIdentifier)
quietPrintf("No model specified, using first loaded model: %s\n", modelIdentifier)
}

fmt.Printf("\nSending prompt to model: %s, temperature: %.2f\n", modelIdentifier, *temperature)
fmt.Printf("Prompt: %s\n", *promptText)
fmt.Println("Response:")
quietPrintf("\nSending prompt to model: %s, temperature: %.2f\n", modelIdentifier, *temperature)
quietPrintf("Prompt: %s\n", *promptText)
quietPrintln("Response:")

// Create a callback to print tokens as they arrive
callback := func(token string) {
// Always print tokens regardless of quiet mode - this is the actual output
fmt.Print(token)
}

if err := client.SendPrompt(modelIdentifier, *promptText, *temperature, callback); err != nil {
logger.Error("Failed to send prompt: %v", err)
os.Exit(1)
}
fmt.Println("")
quietPrintln("") // Print newline only in non-quiet mode
}

// If no operation was specified, list all loaded models as the default behavior
Expand All @@ -451,11 +559,11 @@ func main() {

// Wait for Ctrl+C to exit if requested
if *waitForInterrupt {
fmt.Println("\nPress Ctrl+C to exit")
quietPrintln("\nPress Ctrl+C to exit")
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
<-c
fmt.Println("\nShutting down...")
quietPrintln("\nShutting down...")
}

// Write coverage data if running instrumented build
Expand Down
8 changes: 4 additions & 4 deletions cli/lms_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,12 @@ func testModelLoadingUnloading(t *testing.T) {
return
}

if !strings.Contains(stdout, fmt.Sprintf("Loading model: %s", testModel)) {
t.Errorf("Expected output to contain 'Loading model: %s', got:\n%s", testModel, stdout)
if !strings.Contains(stdout, "Loading model") {
t.Errorf("Expected output to contain 'Loading model'\n%s", stdout)
}

if !strings.Contains(stdout, "loaded successfully") {
t.Errorf("Expected output to contain 'loaded successfully', got:\n%s", stdout)
if !strings.Contains(stdout, testModel) {
t.Errorf("Expected output to contain '%s', got:\n%s", testModel, stdout)
}

fmt.Printf("Step 3: Testing prompt with loaded model '%s'...\n", testModel)
Expand Down
54 changes: 54 additions & 0 deletions pkg/lmstudio/lmstudio_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,60 @@ func (c *LMStudioClient) LoadModel(modelIdentifier string) error {
return c.waitForModelLoading(channel, modelIdentifier, loadTimeout)
}

// LoadModelWithProgress loads a specified model in LM Studio with progress reporting
func (c *LMStudioClient) LoadModelWithProgress(modelIdentifier string, progressCallback func(progress float64, modelInfo *Model)) error {
// Get model information from downloaded models
var modelInfo *Model
downloadedModels, err := c.ListDownloadedModels()
if err != nil {
c.logger.Warn("Failed to get model information: %v", err)
} else {
for _, model := range downloadedModels {
if model.ModelKey == modelIdentifier || model.DisplayName == modelIdentifier || model.ModelName == modelIdentifier {
modelInfo = &model
break
}
}
}

// Check if the model exists in downloaded models
if err := c.checkModelExists(modelIdentifier); err != nil {
return err
}

// Check if the model is already loaded
if c.isModelAlreadyLoaded(modelIdentifier) {
// Call progress callback with 100% completion for already loaded model
if progressCallback != nil {
progressCallback(1.0, modelInfo)
}
return nil
}

// Create a model loading channel with progress callback
c.logger.Debug("Creating model loading channel for: %s", modelIdentifier)
channel, err := c.NewModelLoadingChannel(LLMNamespace, func(progress float64) {
c.logger.Debug("Loading model %s: %.1f%% complete", modelIdentifier, progress*100)
if progressCallback != nil {
progressCallback(progress, modelInfo)
}
})
if err != nil {
return fmt.Errorf("failed to create model loading channel: %w", err)
}
defer channel.Close()

// Create the channel and start loading the model
err = channel.CreateChannel(modelIdentifier)
if err != nil {
return fmt.Errorf("failed to start model loading: %w", err)
}

// Use a longer timeout for model loading - some large models can take several minutes
loadTimeout := 120 * time.Second
return c.waitForModelLoading(channel, modelIdentifier, loadTimeout)
}

// UnloadModel unloads a specified model in LM Studio
func (c *LMStudioClient) UnloadModel(modelIdentifier string) error {
// Get or create a connection to the LLM namespace
Expand Down
Loading