Skip to content

Commit 5755965

Browse files
authored
feat: ctrl+k to pull a model while keeping any changes (#178)
* ctrl+k to pull a model while keeping any changes * ctrl+k to pull a model while keeping any changes * ctrl+k to pull a model while keeping any changes * fix: #175
1 parent 85da549 commit 5755965

File tree

4 files changed

+174
-1
lines changed

4 files changed

+174
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ echo "alias g=gollama" >> ~/.zshrc
116116
- `c`: Copy model
117117
- `U`: Unload all models
118118
- `p`: Pull an existing model
119+
- `ctrl+k`: Pull model & preserve user configuration
119120
- `ctrl+p`: Pull (get) new model
120121
- `P`: Push model
121122
- `n`: Sort by name

app_model.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ func (m *AppModel) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
260260
return m.handlePushModelKey()
261261
case key.Matches(msg, m.keys.PullModel):
262262
return m.handlePullModelKey()
263+
case key.Matches(msg, m.keys.PullKeepConfig):
264+
return m.handlePullKeepConfigKey()
263265
case key.Matches(msg, m.keys.RenameModel):
264266
return m.handleRenameModelKey()
265267
case key.Matches(msg, m.keys.PullNewModel):
@@ -555,6 +557,9 @@ func (m *AppModel) handleUpdateModelKey() (tea.Model, tea.Cmd) {
555557
m.message = fmt.Sprintf("Error updating model: %v", err)
556558
} else {
557559
m.message = message
560+
// Automatically return to main view after editing
561+
m.view = MainView
562+
m.editing = false
558563
}
559564
m.clearScreen()
560565
m.refreshList()
@@ -669,6 +674,18 @@ func (m *AppModel) handlePullModelKey() (tea.Model, tea.Cmd) {
669674
return m, nil
670675
}
671676

677+
// handlePullKeepConfigKey handles the shift+p key to pull a model while preserving user config
678+
func (m *AppModel) handlePullKeepConfigKey() (tea.Model, tea.Cmd) {
679+
logging.DebugLogger.Println("PullKeepConfig key matched")
680+
if item, ok := m.list.SelectedItem().(Model); ok {
681+
m.message = styles.InfoStyle().Render(fmt.Sprintf("Pulling model & preserving config: %s\n", item.Name))
682+
m.pulling = true
683+
m.pullProgress = 0
684+
return m, m.startPullModelPreserveConfig(item.Name)
685+
}
686+
return m, nil
687+
}
688+
672689
func (m *AppModel) handlePullNewModelKey() (tea.Model, tea.Cmd) {
673690
m.pullInput = textinput.New()
674691
m.pullInput.Placeholder = "Enter model name (e.g. llama3:8b-instruct)"

keymap.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type KeyMap struct {
2525
CopyModel key.Binding
2626
PushModel key.Binding
2727
PullModel key.Binding
28+
PullKeepConfig key.Binding
2829
Top key.Binding
2930
AltScreen key.Binding
3031
EditModel key.Binding
@@ -57,6 +58,7 @@ func NewKeyMap() *KeyMap {
5758
LinkModel: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "link (L=all)")),
5859
PushModel: key.NewBinding(key.WithKeys("P"), key.WithHelp("P", "push")),
5960
PullModel: key.NewBinding(key.WithKeys("p"), key.WithHelp("p", "pull")),
61+
PullKeepConfig: key.NewBinding(key.WithKeys("ctrl+k"), key.WithHelp("ctrl+k", "pull & keep config")),
6062
PullNewModel: key.NewBinding(key.WithKeys("ctrl+p"), key.WithHelp("ctrl+p", "pull new model")),
6163
Quit: key.NewBinding(key.WithKeys("q")),
6264
RunModel: key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "run")),

operations.go

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ func editModelfile(client *api.Client, modelName string) (string, error) {
235235
// Write the fetched content to a temporary file
236236
tempDir := os.TempDir()
237237
newModelfilePath := filepath.Join(tempDir, fmt.Sprintf("%s_modelfile.txt", modelName))
238+
239+
// Ensure parent directories exist
240+
parentDir := filepath.Dir(newModelfilePath)
241+
if err := os.MkdirAll(parentDir, 0755); err != nil {
242+
return "", fmt.Errorf("error creating directory for modelfile: %v", err)
243+
}
244+
238245
err = os.WriteFile(newModelfilePath, []byte(modelfileContent), 0644)
239246
if err != nil {
240247
return "", fmt.Errorf("error writing modelfile to temp file: %v", err)
@@ -327,7 +334,7 @@ func editModelfile(client *api.Client, modelName string) (string, error) {
327334
// log to the console if we're not in a tea app
328335
fmt.Printf("Model %s updated successfully\n", modelName)
329336

330-
return fmt.Sprintf("Model %s updated successfully, Press 'q' to return to the models list", modelName), nil
337+
return fmt.Sprintf("Model %s updated successfully", modelName), nil
331338
}
332339

333340
func isLocalhost(url string) bool {
@@ -524,6 +531,78 @@ func (m *AppModel) startPullModel(modelName string) tea.Cmd {
524531
}
525532
}
526533

534+
// startPullModelPreserveConfig pulls a model while preserving user-modified configuration
535+
func (m *AppModel) startPullModelPreserveConfig(modelName string) tea.Cmd {
536+
return func() tea.Msg {
537+
ctx := context.Background()
538+
539+
// Step 1: Extract current parameters and template before pulling
540+
logging.InfoLogger.Printf("Extracting parameters for model %s before pulling", modelName)
541+
currentParams, currentTemplate, systemPrompt, err := getModelParamsWithSystem(modelName, m.client)
542+
if err != nil {
543+
logging.ErrorLogger.Printf("Error extracting parameters for model %s: %v", modelName, err)
544+
return pullErrorMsg{fmt.Errorf("failed to extract parameters: %v", err)}
545+
}
546+
547+
// Step 2: Pull the updated model
548+
logging.InfoLogger.Printf("Pulling updated model: %s", modelName)
549+
req := &api.PullRequest{Name: modelName}
550+
err = m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
551+
m.pullProgress = float64(resp.Completed) / float64(resp.Total)
552+
return nil
553+
})
554+
if err != nil {
555+
return pullErrorMsg{err}
556+
}
557+
558+
// Step 3: Apply the saved configuration back to the updated model
559+
logging.InfoLogger.Printf("Restoring configuration for model: %s", modelName)
560+
561+
// Create request with base fields
562+
createReq := &api.CreateRequest{
563+
Model: modelName, // The model to update
564+
From: modelName, // Use the same model name as base (it's now been updated)
565+
}
566+
567+
// Add template if it exists
568+
if currentTemplate != "" {
569+
createReq.Template = currentTemplate
570+
}
571+
572+
// Add system prompt if it exists
573+
if systemPrompt != "" {
574+
createReq.System = systemPrompt
575+
}
576+
577+
// Add parameters if any were found
578+
if len(currentParams) > 0 {
579+
// Convert map[string]string to map[string]any
580+
parameters := make(map[string]any)
581+
for k, v := range currentParams {
582+
// Try to convert numeric values
583+
if floatVal, err := strconv.ParseFloat(v, 64); err == nil {
584+
parameters[k] = floatVal
585+
} else if intVal, err := strconv.Atoi(v); err == nil {
586+
parameters[k] = intVal
587+
} else {
588+
parameters[k] = v
589+
}
590+
}
591+
createReq.Parameters = parameters
592+
}
593+
594+
// Apply the configuration
595+
err = m.client.Create(ctx, createReq, func(resp api.ProgressResponse) error {
596+
return nil
597+
})
598+
if err != nil {
599+
return pullErrorMsg{fmt.Errorf("failed to restore configuration: %v", err)}
600+
}
601+
602+
return pullSuccessMsg{modelName}
603+
}
604+
}
605+
527606
type editorFinishedMsg struct{ err error }
528607

529608
func cleanupSymlinkedModels(lmStudioModelsDir string) {
@@ -796,6 +875,80 @@ func extractTemplateAndSystem(content string) (template string, system string) {
796875
return template, system
797876
}
798877

878+
// getModelParamsWithSystem extracts parameters, template, and system prompt from a model's modelfile
879+
func getModelParamsWithSystem(modelName string, client *api.Client) (map[string]string, string, string, error) {
880+
logging.InfoLogger.Printf("Getting parameters and system prompt for model: %s\n", modelName)
881+
ctx := context.Background()
882+
req := &api.ShowRequest{Name: modelName}
883+
resp, err := client.Show(ctx, req)
884+
if err != nil {
885+
logging.ErrorLogger.Printf("Error getting modelfile for %s: %v\n", modelName, err)
886+
return nil, "", "", err
887+
}
888+
output := []byte(resp.Modelfile)
889+
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
890+
params := make(map[string]string)
891+
var template string
892+
var system string
893+
894+
inTemplate := false
895+
inMultilineTemplate := false
896+
var templateLines []string
897+
898+
for _, line := range lines {
899+
trimmed := strings.TrimSpace(line)
900+
901+
// Handle TEMPLATE directive
902+
if strings.HasPrefix(trimmed, "TEMPLATE") {
903+
if strings.Contains(trimmed, `"""`) {
904+
// Multi-line template
905+
templateContent := strings.TrimPrefix(trimmed, "TEMPLATE ")
906+
templateContent = strings.TrimSpace(templateContent)
907+
if strings.HasPrefix(templateContent, `"""`) {
908+
templateContent = strings.TrimPrefix(templateContent, `"""`)
909+
}
910+
inTemplate = true
911+
inMultilineTemplate = true
912+
if templateContent != "" {
913+
templateLines = append(templateLines, templateContent)
914+
}
915+
} else {
916+
// Single-line template
917+
template = strings.TrimPrefix(trimmed, "TEMPLATE ")
918+
template = strings.Trim(template, `"`)
919+
}
920+
} else if inTemplate {
921+
if inMultilineTemplate && strings.HasSuffix(trimmed, `"""`) {
922+
line = strings.TrimSuffix(line, `"""`)
923+
if line != "" {
924+
templateLines = append(templateLines, line)
925+
}
926+
inTemplate = false
927+
inMultilineTemplate = false
928+
} else {
929+
templateLines = append(templateLines, line)
930+
}
931+
} else if strings.HasPrefix(trimmed, "SYSTEM") {
932+
system = strings.TrimPrefix(trimmed, "SYSTEM ")
933+
// Remove surrounding quotes if present
934+
system = strings.Trim(system, `"`)
935+
} else if strings.HasPrefix(trimmed, "PARAMETER") {
936+
parts := strings.SplitN(trimmed, " ", 3)
937+
if len(parts) >= 3 {
938+
key := parts[1]
939+
value := strings.TrimSpace(parts[2])
940+
params[key] = value
941+
}
942+
}
943+
}
944+
945+
if len(templateLines) > 0 {
946+
template = strings.Join(templateLines, "\n")
947+
}
948+
949+
return params, template, system, nil
950+
}
951+
799952
// extractParameters extracts parameter values from modelfile content
800953
func extractParameters(content string) map[string]any {
801954
parameters := make(map[string]any)

0 commit comments

Comments
 (0)