diff --git a/api/compute/filter.go b/api/compute/filter.go index 9f39c5dc5..a608095ad 100644 --- a/api/compute/filter.go +++ b/api/compute/filter.go @@ -85,7 +85,7 @@ func filterData(client *compute.Client, ds *api.Dataset, filterParams *api.Filte // output the filtered results as the data in the filtered dataset _, outputDataFile := getPreFilteringOutputDataFile(outputFolder) - err = util.CopyFile(filteredData, outputDataFile) + err = util.CopyFile(filteredData.ResultURI, outputDataFile) if err != nil { return "", nil, err } diff --git a/api/compute/pipeline.go b/api/compute/pipeline.go index 1be073ca9..fbd979de7 100644 --- a/api/compute/pipeline.go +++ b/api/compute/pipeline.go @@ -115,10 +115,17 @@ type QueueItem struct { // QueueResponse represents the result from processing a queue item. type QueueResponse struct { - Output interface{} + Output *PipelineOutput Error error } +// PipelineOutput represents an output from executing a queued pipeline. +type PipelineOutput struct { + SolutionID string + ResultURI string + FittedSolutionID string +} + // Queue uses a buffered channel to queue tasks and provides the result via channels. type Queue struct { mu sync.RWMutex @@ -194,6 +201,9 @@ func (q *Queue) Done() { // InitializeCache sets up an empty cache or if a source file provided, reads // the cache from the source file. func InitializeCache(sourceFile string, readEnabled bool) error { + // register the output type for the cache! + gob.Register(&PipelineOutput{}) + var c *gc.Cache if util.FileExists(sourceFile) { b, err := ioutil.ReadFile(sourceFile) @@ -234,7 +244,7 @@ func InitializeQueue(config *env.Config) { // SubmitPipeline executes pipelines using the client and returns the result URI. func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce []string, searchRequest *pipeline.SearchSolutionsRequest, - fullySpecifiedStep *description.FullySpecifiedPipeline, allowedValueTypes []string, shouldCache bool) (string, error) { + fullySpecifiedStep *description.FullySpecifiedPipeline, allowedValueTypes []string, shouldCache bool) (*PipelineOutput, error) { request := compute.NewExecPipelineRequest(datasets, datasetsProduce, fullySpecifiedStep.Pipeline) @@ -254,12 +264,12 @@ func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce [ if cache.readEnabled { if shouldCache { if err != nil { - return "", err + return nil, err } entry, found := cache.cache.Get(hashedPipelineUniqueKey) if found { log.Infof("returning cached entry for pipeline") - return entry.(string), nil + return entry.(*PipelineOutput), nil } } } else { @@ -268,7 +278,7 @@ func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce [ // get equivalency key for enqueuing hashedPipelineEquivKey, err := queueTask.hashEquivalent() if err != nil { - return "", err + return nil, err } resultChan := queue.Enqueue(hashedPipelineEquivKey, queueTask) @@ -276,17 +286,16 @@ func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce [ result := <-resultChan if result.Error != nil { - return "", result.Error + return nil, result.Error } - datasetURI := result.Output.(string) - cache.cache.Set(hashedPipelineUniqueKey, datasetURI, gc.DefaultExpiration) + cache.cache.Set(hashedPipelineUniqueKey, result.Output, gc.DefaultExpiration) err = cache.PersistCache() if err != nil { log.Warnf("error persisting cache: %v", err) } - return datasetURI, nil + return result.Output, nil } func runPipelineQueue(queue *Queue) { @@ -316,6 +325,8 @@ func runPipelineQueue(queue 
*Queue) { // listen for completion var errPipeline error var datasetURI string + var fittedSolutionID string + var solutionID string err = pipelineTask.request.Listen(func(status compute.ExecPipelineStatus) { // check for error if status.Error != nil { @@ -324,6 +335,8 @@ func runPipelineQueue(queue *Queue) { if status.Progress == compute.RequestCompletedStatus { datasetURI = status.ResultURI + fittedSolutionID = status.FittedSolutionID + solutionID = status.SolutionID } }) if err != nil { @@ -342,7 +355,11 @@ func runPipelineQueue(queue *Queue) { datasetURI = strings.Replace(datasetURI, "file://", "", -1) - queueTask.returnResult(&QueueResponse{Output: datasetURI}) + queueTask.returnResult(&QueueResponse{&PipelineOutput{ + ResultURI: datasetURI, + FittedSolutionID: fittedSolutionID, + SolutionID: solutionID, + }, nil}) } log.Infof("ending queue processing") diff --git a/api/compute/search.go b/api/compute/search.go index 2d4953d0f..e84aff02c 100644 --- a/api/compute/search.go +++ b/api/compute/search.go @@ -39,6 +39,7 @@ type searchResult struct { type pipelineSearchContext struct { searchID string dataset string + task []string storageName string sourceDatasetURI string trainDatasetURI string @@ -278,7 +279,7 @@ func (s *SolutionRequest) dispatchSolutionSearchPipeline(statusChan chan Solutio if ok { // reformat result to have one row per d3m index since confidences // can produce one row / class - resultURI, err = reformatResult(resultURI) + resultURI, err = reformatResult(resultURI, s.TargetFeature.HeaderName) if err != nil { return nil, err } diff --git a/api/compute/segment.go b/api/compute/segment.go new file mode 100644 index 000000000..a68f3f314 --- /dev/null +++ b/api/compute/segment.go @@ -0,0 +1,58 @@ +// +// Copyright © 2021 Uncharted Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "strconv" + + "github.com/pkg/errors" + + "github.com/uncharted-distil/distil/api/util/imagery" +) + +// BuildSegmentationImage uses the raw segmentation output to build an image layer. 
+func BuildSegmentationImage(rawSegmentation [][]interface{}) (map[string][]byte, error) { + // output is mapping of d3m index to new segmentation layer + output := map[string][]byte{} + // need to output all the masks as images + for _, r := range rawSegmentation[1:] { + // create the image that captures the mask + d3mIndex := r[0].(string) + rawMask := r[1].([]interface{}) + rawFloats := make([][]float64, len(rawMask)) + for i, f := range rawMask { + dataF := f.([]interface{}) + nestedFloats := make([]float64, len(dataF)) + for j, nf := range dataF { + fp, err := strconv.ParseFloat(nf.(string), 64) + if err != nil { + return nil, errors.Wrapf(err, "unable to parse mask") + } + nestedFloats[j] = fp + } + rawFloats[i] = nestedFloats + } + + filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(255)) + imageBytes, err := imagery.ImageToPNG(filter) + if err != nil { + return nil, err + } + output[d3mIndex] = imageBytes + } + + return output, nil +} diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go index 22dbd009b..f4fb7a3b7 100644 --- a/api/compute/solution_request.go +++ b/api/compute/solution_request.go @@ -18,6 +18,7 @@ package compute import ( "context" "fmt" + "os" "path" "path/filepath" "strconv" @@ -33,8 +34,11 @@ import ( "github.com/uncharted-distil/distil-compute/pipeline" "github.com/uncharted-distil/distil-compute/primitive/compute" "github.com/uncharted-distil/distil-compute/primitive/compute/description" + "github.com/uncharted-distil/distil-compute/primitive/compute/result" + "github.com/uncharted-distil/distil/api/env" api "github.com/uncharted-distil/distil/api/model" "github.com/uncharted-distil/distil/api/serialization" + "github.com/uncharted-distil/distil/api/util" "github.com/uncharted-distil/distil/api/util/json" log "github.com/unchartedsoftware/plog" "google.golang.org/grpc/codes" @@ -323,7 +327,8 @@ func (s *SolutionRequest) createPreprocessingPipeline(featureVariables []*model. } // GeneratePredictions produces predictions using the specified. 
-func GeneratePredictions(datasetURI string, solutionID string, fittedSolutionID string, client *compute.Client) (*PredictionResult, error) { +func GeneratePredictions(datasetID string, datasetURI string, solutionID string, fittedSolutionID string, task *Task, + targetName string, metaStorage api.MetadataStorage, dataStorage api.DataStorage, client *compute.Client) (*PredictionResult, error) { // check if the solution can be explained desc, err := client.GetSolutionDescription(context.Background(), solutionID) if err != nil { @@ -355,7 +360,14 @@ func GeneratePredictions(datasetURI string, solutionID string, fittedSolutionID if err != nil { return nil, err } - resultURI, err = reformatResult(resultURI) + + // segmentation results need to be reduced to tagging segmented images + if HasTaskType(task, compute.SegmentationTask) && isSegmentationOutput(resultURI) { + resultURI, err = createSegmentationResult(datasetID, resultURI, targetName, metaStorage, dataStorage) + } else { + resultURI, err = reformatResult(resultURI, targetName) + } + if err != nil { return nil, err } @@ -483,10 +495,11 @@ func (s *SolutionRequest) persistSolutionStatus(statusChan chan SolutionStatus, } } -func (s *SolutionRequest) persistRequestError(statusChan chan SolutionStatus, solutionStorage api.SolutionStorage, searchID string, dataset string, err error) { +func (s *SolutionRequest) persistRequestError(statusChan chan SolutionStatus, + solutionStorage api.SolutionStorage, searchID string, dataset string, task []string, err error) { // persist the updated state // NOTE: ignoring error - _ = solutionStorage.PersistRequest(searchID, dataset, compute.RequestErroredStatus, time.Now()) + _ = solutionStorage.PersistRequest(searchID, dataset, task, compute.RequestErroredStatus, time.Now()) // notify of error statusChan <- SolutionStatus{ @@ -497,12 +510,13 @@ func (s *SolutionRequest) persistRequestError(statusChan chan SolutionStatus, so } } -func (s *SolutionRequest) persistRequestStatus(statusChan chan SolutionStatus, solutionStorage api.SolutionStorage, searchID string, dataset string, status string) error { +func (s *SolutionRequest) persistRequestStatus(statusChan chan SolutionStatus, + solutionStorage api.SolutionStorage, searchID string, dataset string, task []string, status string) error { // persist the updated state - err := solutionStorage.PersistRequest(searchID, dataset, status, time.Now()) + err := solutionStorage.PersistRequest(searchID, dataset, task, status, time.Now()) if err != nil { // notify of error - s.persistRequestError(statusChan, solutionStorage, searchID, dataset, err) + s.persistRequestError(statusChan, solutionStorage, searchID, dataset, task, err) return err } @@ -563,9 +577,8 @@ func describeSolution(client *compute.Client, initialSearchSolutionID string) (* func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorage api.SolutionStorage, dataStorage api.DataStorage, searchContext pipelineSearchContext) { - // update request status - err := s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, compute.RequestRunningStatus) + err := s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, searchContext.task, compute.RequestRunningStatus) if err != nil { s.finished <- err return @@ -615,9 +628,9 @@ func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorag // update request status if err != nil { - s.persistRequestError(s.requestChannel, 
solutionStorage, searchContext.searchID, searchContext.dataset, err) + s.persistRequestError(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, searchContext.task, err) } else { - if err = s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, compute.RequestCompletedStatus); err != nil { + if err = s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, searchContext.task, compute.RequestCompletedStatus); err != nil { log.Errorf("failed to persist status %s for search %s", compute.RequestCompletedStatus, searchContext.searchID) } } @@ -631,6 +644,174 @@ func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorag s.finished <- nil } +func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage, + dataStorage api.DataStorage, client *compute.Client, datasetInputDir string, step *description.FullySpecifiedPipeline) { + log.Infof("dispatching segmentation pipeline") + + // create the backing data + err := s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, s.Dataset, s.Task, compute.RequestRunningStatus) + if err != nil { + s.finished <- err + return + } + + c := newStatusChannel() + + // run the pipeline + pipelineResult, err := SubmitPipeline(client, []string{datasetInputDir}, nil, nil, step, nil, false) + if err != nil { + s.finished <- err + return + } + + // add the solution to the request + // doing this after submission to have the solution id available! + s.addSolution(c) + s.persistSolution(c, solutionStorage, requestID, pipelineResult.SolutionID, "") + s.persistSolutionStatus(c, solutionStorage, requestID, pipelineResult.SolutionID, compute.SolutionPendingStatus) + s.persistSolutionStatus(c, solutionStorage, requestID, pipelineResult.SolutionID, compute.SolutionScoringStatus) + + // HACK: MAKE UP A SOLUTION SCORE!!! + err = solutionStorage.PersistSolutionScore(pipelineResult.SolutionID, util.F1Micro, 0.5) + if err != nil { + s.finished <- err + return + } + s.persistSolutionStatus(c, solutionStorage, requestID, pipelineResult.SolutionID, compute.SolutionProducingStatus) + + // update status and respond to client as needed + uuidGen, err := uuid.NewV4() + if err != nil { + s.finished <- errors.Wrapf(err, "unable to generate solution id") + return + } + resultID := uuidGen.String() + c <- SolutionStatus{ + RequestID: requestID, + SolutionID: pipelineResult.SolutionID, + ResultID: resultID, + Progress: compute.SolutionCompletedStatus, + Timestamp: time.Now(), + } + close(c) + + // get the grouping key since it makes up part of the filename + log.Infof("processing segmentation pipeline output") + dataset, err := metaStorage.FetchDataset(s.Dataset, true, true, false) + if err != nil { + s.finished <- err + return + } + + // HACK: INPUT FAKE RESULTS TO THE DB!!! + // FAKE RESULTS SHOULD JUST BE A CONSTANT! + uuidGen, err = uuid.NewV4() + if err != nil { + s.finished <- errors.Wrapf(err, "unable to generate produce request id") + return + } + produceRequestID := uuidGen.String() + + // HACK: CREATE FAKE RESULTS TO PERSIST AS THE ACTUAL RESULTS SHOULD NOT BE STORED IN THE DB!!! 
+ resultOutputURI, err := createSegmentationResult(s.Dataset, pipelineResult.ResultURI, s.TargetFeature.HeaderName, metaStorage, dataStorage) + if err != nil { + s.finished <- err + return + } + + log.Infof("persisting results in URI '%s'", resultOutputURI) + err = s.persistSolutionResults(c, client, solutionStorage, dataStorage, requestID, dataset.ID, + dataset.StorageName, pipelineResult.SolutionID, pipelineResult.FittedSolutionID, produceRequestID, resultID, resultOutputURI) + if err != nil { + s.finished <- errors.Wrapf(err, "unable to persist solution result") + return + } + + log.Infof("segmentation pipeline processing complete") + + err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, s.Task, compute.RequestCompletedStatus) + if err != nil { + s.finished <- err + return + } + close(s.requestChannel) + s.finished <- nil +} + +func processSegmentation(s *SolutionRequest, client *compute.Client, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage, dataStorage api.DataStorage) error { + // create the fully specified pipeline + envConfig, err := env.LoadConfig() + if err != nil { + return err + } + + // fetch the source dataset + dataset, err := metaStorage.FetchDataset(s.Dataset, true, true, false) + if err != nil { + return nil + } + s.DatasetMetadata = dataset + variablesMap := api.MapVariables(dataset.Variables, func(v *model.Variable) string { return v.Key }) + + datasetInputDir := env.ResolvePath(dataset.Source, dataset.Folder) + + step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", s.TargetFeature, + envConfig.RemoteSensingNumJobs, envConfig.RemoteSensingGPUBatchSize) + if err != nil { + return err + } + + // need a request ID + uuidGen, err := uuid.NewV4() + if err != nil { + return err + } + requestID := uuidGen.String() + + // persist the request + err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, s.Task, compute.RequestPendingStatus) + if err != nil { + return err + } + + // store the request features - note that we are storing the original request filters, not the expanded + // list that was generated + // also note that augmented features should not be included + for _, v := range s.Filters.Variables { + var typ string + // ignore the index field + if v == model.D3MIndexFieldName { + continue + } else if variablesMap[v].HasRole(model.VarDistilRoleAugmented) { + continue + } + + if v == s.TargetFeature.Key { + // store target feature + typ = model.FeatureTypeTarget + } else { + // store training feature + typ = model.FeatureTypeTrain + } + err = solutionStorage.PersistRequestFeature(requestID, v, typ) + if err != nil { + return err + } + } + + // store the original request filters + // HACK: NO FILTERS SUPPORTED FOR SEGMENTATION! + err = solutionStorage.PersistRequestFilters(requestID, s.Filters) + if err != nil { + return err + } + + // dispatch it as if it were a model search + go dispatchSegmentation(s, requestID, solutionStorage, metaStorage, dataStorage, client, datasetInputDir, step) + + return nil +} + // PersistAndDispatch persists the solution request and dispatches it. 
func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage, dataStorage api.DataStorage) error { @@ -706,18 +887,6 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto } s.Filters = updatedFilters - // get the target - datasetInputDir := filteredDatasetPath - meta, err := serialization.ReadMetadata(path.Join(datasetInputDir, compute.D3MDataSchema)) - if err != nil { - return err - } - metaVars := meta.GetMainDataResource().Variables - targetVariable, err = findVariable(targetVariable.Key, metaVars) - if err != nil { - return err - } - if dataset.LearningDataset != "" { s.useParquet = true groupingVariableIndex = -1 @@ -728,11 +897,22 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto if err != nil { return err } - task, err := ResolveTask(dataStorage, dataset.StorageName, s.TargetFeature, trainingVariables) - if err != nil { - return err + + var task *Task + if len(s.Task) > 0 { + task = &Task{s.Task} + } else { + task, err = ResolveTask(dataStorage, dataset.StorageName, s.TargetFeature, trainingVariables) + if err != nil { + return err + } + s.Task = task.Task } - s.Task = task.Task + + if HasTaskType(task, compute.SegmentationTask) { + return processSegmentation(s, client, solutionStorage, metaStorage, dataStorage) + } + // check if TimestampSplitValue is not 0 if s.TimestampSplitValue > 0 { found := false @@ -751,6 +931,17 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto } } + // get the target + meta, err := serialization.ReadMetadata(path.Join(filteredDatasetPath, compute.D3MDataSchema)) + if err != nil { + return err + } + metaVars := meta.GetMainDataResource().Variables + targetVariable, err = findVariable(targetVariable.Key, metaVars) + if err != nil { + return err + } + // when dealing with categorical data we want to stratify stratify := model.IsCategorical(s.TargetFeature.Type) // create the splitter to use for the train / test split @@ -805,7 +996,7 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto } // persist the request - err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, compute.RequestPendingStatus) + err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, s.Task, compute.RequestPendingStatus) if err != nil { return err } @@ -845,8 +1036,9 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto searchContext := pipelineSearchContext{ searchID: requestID, dataset: dataset.ID, + task: s.Task, storageName: dataset.StorageName, - sourceDatasetURI: datasetInputDir, + sourceDatasetURI: filteredDatasetPath, trainDatasetURI: datasetPathTrain, testDatasetURI: datasetPathTest, produceDatasetURI: datasetPathTest, @@ -912,7 +1104,7 @@ type confidenceValue struct { row int } -func reformatResult(resultURI string) (string, error) { +func reformatResult(resultURI string, targetName string) (string, error) { // read data from original file dataReader := serialization.GetStorage(resultURI) data, err := dataReader.ReadData(resultURI) @@ -968,3 +1160,74 @@ func reformatResult(resultURI string) (string, error) { return filteredURI, nil } + +// isSegmentationOutput checks if a result is from an image segmentation pipeline. +// NOTE: returns false if it cannot confirm it is segmentation (ex: exception occurs)! 
+func isSegmentationOutput(resultURI string) bool { + result, err := result.ParseResultCSV(resultURI) + if err != nil { + return false + } + + // segmentation output has as header "d3mIndex,positive_mask" + return len(result[0]) == 2 && result[0][0].(string) == model.D3MIndexFieldName && result[0][1].(string) == "positive_mask" +} + +func createSegmentationResult(datasetID string, resultURI string, + targetName string, metaStorage api.MetadataStorage, dataStorage api.DataStorage) (string, error) { + log.Infof("processing segmentation pipeline output") + result, err := result.ParseResultCSV(resultURI) + if err != nil { + return "", err + } + + images, err := BuildSegmentationImage(result) + if err != nil { + return "", err + } + + // get the grouping key since it makes up part of the filename + dataset, err := metaStorage.FetchDataset(datasetID, true, true, false) + if err != nil { + return "", err + } + + var groupingKey *model.Variable + for _, v := range dataset.Variables { + if v.HasRole(model.VarDistilRoleGrouping) { + groupingKey = v + break + } + } + if groupingKey == nil { + return "", errors.Errorf("no grouping found to use for output filename") + } + + // get the d3m index -> grouping key mapping + mapping, err := api.BuildFieldMapping(dataset.ID, dataset.StorageName, model.D3MIndexFieldName, groupingKey.Key, dataStorage) + if err != nil { + return "", err + } + + imageOutputFolder := path.Join(env.GetResourcePath(), dataset.ID, "media") + for d3mIndex, imageBytes := range images { + imageFilename := path.Join(imageOutputFolder, fmt.Sprintf("%s-segmentation.png", mapping[d3mIndex])) + err = util.WriteFileWithDirs(imageFilename, imageBytes, os.ModePerm) + if err != nil { + return "", err + } + } + + resultOutput := []string{fmt.Sprintf("%s,%s,%s", model.D3MIndexFieldName, targetName, "confidence")} + for i := 1; i < len(result); i++ { + resultOutput = append(resultOutput, fmt.Sprintf("%s,%s,%d", result[i][0], "segmented", 1)) + } + resultOutputURI := fmt.Sprintf("%s-distil%s", resultURI[:len(resultURI)-4], resultURI[len(resultURI)-4:]) + log.Infof("writing distil formatted segmentation results to '%s'", resultOutputURI) + err = util.WriteFileWithDirs(resultOutputURI, []byte(strings.Join(resultOutput, "\n")), os.ModePerm) + if err != nil { + return "", err + } + + return resultOutputURI, nil +} diff --git a/api/compute/split.go b/api/compute/split.go index 4d90964b8..8e0030ba0 100644 --- a/api/compute/split.go +++ b/api/compute/split.go @@ -56,6 +56,10 @@ type basicSplitter struct { trainTestSplit float64 } +type copySplitter struct { + rowLimits rowLimits +} + type stratifiedSplitter struct { rowLimits rowLimits targetCol int @@ -204,6 +208,50 @@ func (b *basicSplitter) sample(data [][]string, maxRows int) [][]string { return output } +func (c *copySplitter) hash(schemaFile string, params ...interface{}) (uint64, error) { + // generate the hash from the params + hashStruct := struct { + Schema string + Copy bool + RowLimits rowLimits + Params []interface{} + }{ + Schema: schemaFile, + Copy: true, + RowLimits: c.rowLimits, + Params: params, + } + hash, err := hashstructure.Hash(hashStruct, nil) + if err != nil { + return 0, errors.Wrap(err, "failed to generate persisted data hash") + } + return hash, nil +} + +func (c *copySplitter) split(data [][]string) ([][]string, [][]string, error) { + log.Infof("splitting data using copy splitter...") + // create the output + outputTrain := [][]string{} + outputTest := [][]string{} + + // handle the header + inputData, outputTrain, outputTest 
:= splitTrainTestHeader(data, outputTrain, outputTest, true) + + numTrainingRows := c.rowLimits.trainingRows(len(inputData)) + + // sample to meet row limit constraints + output := c.sample(inputData, numTrainingRows) + + return append(outputTrain, output...), append(outputTest, output...), nil +} + +func (c *copySplitter) sample(data [][]string, maxRows int) [][]string { + output := [][]string{} + output, _ = shuffleAndWrite(data[1:], -1, maxRows, 0, false, output, nil, float64(1)) + + return output +} + func (s *stratifiedSplitter) hash(schemaFile string, params ...interface{}) (uint64, error) { // generate the hash from the params hashStruct := struct { @@ -496,6 +544,10 @@ func createSplitter(taskType []string, targetFieldIndex int, groupingFieldIndex trainTestSplit: trainTestSplit, }, } + } else if task == compute.SegmentationTask { + return ©Splitter{ + rowLimits: limits, + } } } // if not null diff --git a/api/compute/task.go b/api/compute/task.go index db4543124..408f01363 100644 --- a/api/compute/task.go +++ b/api/compute/task.go @@ -88,7 +88,7 @@ func ResolveTask(storage api.DataStorage, datasetStorageName string, targetVaria if model.IsImage(feature.Type) { task = append(task, compute.ImageTask) } else if model.IsMultiBandImage(feature.Type) { - task = append(task, compute.RemoteSensingTask) + task = append(task, compute.RemoteSensingTask, compute.SegmentationTask) } else if model.IsTimeSeries(feature.Type) { task = append(task, compute.TimeSeriesTask) } @@ -110,7 +110,7 @@ func ResolveTask(storage api.DataStorage, datasetStorageName string, targetVaria task = append(task, compute.SemiSupervisedTask) } // If there are 3 labels (2 + empty), update this as a binary classification task - if len(targetCounts) == 2 { + if len(targetCounts) == 3 { task = append(task, compute.BinaryTask) } else { task = append(task, compute.MultiClassTask) diff --git a/api/env/config.go b/api/env/config.go index 167ef5a13..33c2539ad 100644 --- a/api/env/config.go +++ b/api/env/config.go @@ -76,6 +76,7 @@ type Config struct { PostgresPassword string `env:"PG_PASSWORD" envDefault:""` PostgresPort int `env:"PG_PORT" envDefault:"5432"` PostgresRandomSeed float64 `env:"PG_RANDOM_SEED" envDefault:"0.2"` + PostgresUpdate bool `env:"PG_UPDATE" envDefault:"false"` PostgresUser string `env:"PG_USER" envDefault:"distil"` PublicSubFolder string `env:"PUBLIC_SUBFOLDER" envDefault:"public"` RankingOutputPath string `env:"RANKING_OUTPUT_PATH" envDefault:"importance.json"` @@ -84,6 +85,7 @@ type Config struct { ResourceSubFolder string `env:"RESOURCE_SUBFOLDER" envDefault:"resources"` ShouldScaleImages bool `env:"SHOULD_SCALE_IMAGES" envDefault:"false"` // enables and disables image scaling SkipPreprocessing bool `env:"SKIP_PREPROCESSING" envDefault:"false"` + SegmentationEnabled bool `env:"SEGMENTATION_ENABLED" envDefault:"false"` SolutionComputeEndpoint string `env:"SOLUTION_COMPUTE_ENDPOINT" envDefault:"localhost:50051"` SolutionComputePullTimeout int `env:"SOLUTION_COMPUTE_PULL_TIMEOUT" envDefault:"60"` SolutionComputePullMax int `env:"SOLUTION_COMPUTE_PULL_MAX" envDefault:"10"` diff --git a/api/model/grouped_variables.go b/api/model/grouped_variables.go index 9da46392c..70560e63d 100644 --- a/api/model/grouped_variables.go +++ b/api/model/grouped_variables.go @@ -18,6 +18,8 @@ package model import ( "fmt" + "github.com/pkg/errors" + "github.com/uncharted-distil/distil-compute/model" log "github.com/unchartedsoftware/plog" ) @@ -171,6 +173,39 @@ func GetClusterColFromGrouping(group model.BaseGrouping) (string, 
bool) { return "", false } +// BuildFieldMapping builds a mapping from a source field to a target field. +func BuildFieldMapping(dsID string, dsStorageName string, sourceFieldName string, + targetFieldName string, dataStorage DataStorage) (map[string]string, error) { + filter := &FilterParams{Variables: []string{sourceFieldName, targetFieldName}} + + // pull back all rows for a group id + data, err := dataStorage.FetchData(dsID, dsStorageName, filter, true, nil) + if err != nil { + return nil, err + } + + // cycle through results to build the band mapping + targetFieldColumn, ok := data.Columns[targetFieldName] + if !ok { + return nil, errors.Errorf("'%s' column not found in stored data", targetFieldName) + } + targetFieldColumnIndex := targetFieldColumn.Index + sourceColumn, ok := data.Columns[sourceFieldName] + if !ok { + return nil, errors.Errorf("'%s' column not found in stored data", sourceFieldName) + } + sourceColumnIndex := sourceColumn.Index + + mapping := map[string]string{} + for _, r := range data.Values { + sourceData := fmt.Sprintf("%.0f", r[sourceColumnIndex].Value.(float64)) + fieldData := r[targetFieldColumnIndex].Value.(string) + mapping[sourceData] = fieldData + } + + return mapping, nil +} + // UpdateFilterKey updates the supplied filter key to point to a group-specific column, rather than relying on the group variable // name. func UpdateFilterKey(metaStore MetadataStorage, dataset string, dataMode DataMode, filter *model.Filter, variable *model.Variable) { diff --git a/api/model/model.go b/api/model/model.go index c1f10ea2c..5b13fb6d5 100644 --- a/api/model/model.go +++ b/api/model/model.go @@ -42,6 +42,7 @@ type ExportedModel struct { FittedSolutionID string `json:"fittedSolutionId"` DatasetID string `json:"datasetId"` DatasetName string `json:"datasetName"` + Task []string `json:"task"` Target *SolutionVariable `json:"target"` Variables []string `json:"variables"` VariableDetails []*SolutionVariable `json:"variableDetails"` @@ -52,6 +53,7 @@ type ExportedModel struct { type Request struct { RequestID string `json:"requestId"` Dataset string `json:"dataset"` + Task []string `json:"task"` Progress string `json:"progress"` CreatedTime time.Time `json:"timestamp"` LastUpdatedTime time.Time `json:"lastUpdatedTime"` diff --git a/api/model/storage.go b/api/model/storage.go index 42745c544..d735460f0 100644 --- a/api/model/storage.go +++ b/api/model/storage.go @@ -138,7 +138,7 @@ type SolutionStorageCtor func() (SolutionStorage, error) // solution storage. 
type SolutionStorage interface { PersistPrediction(requestID string, dataset string, target string, fittedSolutionID string, progress string, createdTime time.Time) error - PersistRequest(requestID string, dataset string, progress string, createdTime time.Time) error + PersistRequest(requestID string, dataset string, task []string, progress string, createdTime time.Time) error PersistRequestFeature(requestID string, featureName string, featureType string) error PersistRequestFilters(requestID string, filters *FilterParams) error PersistSolution(requestID string, solutionID string, explainedSolutionID string, createdTime time.Time) error diff --git a/api/model/storage/elastic/model.go b/api/model/storage/elastic/model.go index c8d42f939..76667b35e 100644 --- a/api/model/storage/elastic/model.go +++ b/api/model/storage/elastic/model.go @@ -122,6 +122,10 @@ func (s *Storage) parseModels(res *elastic.SearchResult, includeDeleted bool) ([ if !ok { return nil, errors.New("failed to parse the dataset id") } + + // get the task + tasks, _ := json.StringArray(src, "task") + // extract the target targetInfo, ok := json.Get(src, "target") if !ok { @@ -153,6 +157,7 @@ func (s *Storage) parseModels(res *elastic.SearchResult, includeDeleted bool) ([ FittedSolutionID: fittedSolutionID, DatasetID: datasetID, DatasetName: name, + Task: tasks, Target: target, Variables: variables, VariableDetails: variableDetails, diff --git a/api/model/storage/postgres/dataset.go b/api/model/storage/postgres/dataset.go index 88e1b4c5a..27df1d8d6 100644 --- a/api/model/storage/postgres/dataset.go +++ b/api/model/storage/postgres/dataset.go @@ -357,7 +357,7 @@ func (s *Storage) FetchDataset(dataset string, storageName string, filteredVars := []*model.Variable{} selectedVars := map[string]bool{} - if filterParams != nil { + if limitSelectedFields && filterParams != nil { for _, v := range filterParams.Variables { selectedVars[v] = true } diff --git a/api/model/storage/postgres/request.go b/api/model/storage/postgres/request.go index df4aa0f68..fe7e11029 100644 --- a/api/model/storage/postgres/request.go +++ b/api/model/storage/postgres/request.go @@ -28,10 +28,10 @@ import ( ) // PersistRequest persists a request to Postgres. -func (s *Storage) PersistRequest(requestID string, dataset string, progress string, createdTime time.Time) error { - sql := fmt.Sprintf("INSERT INTO %s (request_id, dataset, progress, created_time, last_updated_time) VALUES ($1, $2, $3, $4, $4);", postgres.RequestTableName) +func (s *Storage) PersistRequest(requestID string, dataset string, task []string, progress string, createdTime time.Time) error { + sql := fmt.Sprintf("INSERT INTO %s (request_id, dataset, task, progress, created_time, last_updated_time) VALUES ($1, $2, $3, $4, $5, $5);", postgres.RequestTableName) - _, err := s.client.Exec(sql, requestID, dataset, progress, createdTime) + _, err := s.client.Exec(sql, requestID, dataset, strings.Join(task, ","), progress, createdTime) return errors.Wrapf(err, "failed to persist request to PostGres") } @@ -93,7 +93,7 @@ func (s *Storage) PersistRequestFilters(requestID string, filters *api.FilterPar // FetchRequest pulls request information from Postgres. 
func (s *Storage) FetchRequest(requestID string) (*api.Request, error) { - sql := fmt.Sprintf("SELECT request_id, dataset, progress, created_time, last_updated_time FROM %s WHERE request_id = $1 ORDER BY created_time desc LIMIT 1;", postgres.RequestTableName) + sql := fmt.Sprintf("SELECT request_id, dataset, task, progress, created_time, last_updated_time FROM %s WHERE request_id = $1 ORDER BY created_time desc LIMIT 1;", postgres.RequestTableName) rows, err := s.client.Query(sql, requestID) if err != nil { @@ -114,7 +114,7 @@ func (s *Storage) FetchRequest(requestID string) (*api.Request, error) { // FetchRequestByResultUUID pulls request information from Postgres using // a result UUID. func (s *Storage) FetchRequestByResultUUID(resultUUID string) (*api.Request, error) { - sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.progress, req.created_time, req.last_updated_time "+ + sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.task, req.progress, req.created_time, req.last_updated_time "+ "FROM %s as req INNER JOIN %s as sol ON req.request_id = sol.request_id INNER JOIN %s as sol_res ON sol.solution_id = sol_res.solution_id "+ "WHERE sol_res.result_uuid = $1;", postgres.RequestTableName, postgres.SolutionTableName, postgres.SolutionResultTableName) @@ -139,7 +139,7 @@ func (s *Storage) FetchRequestByResultUUID(resultUUID string) (*api.Request, err // FetchRequestBySolutionID pulls request information from Postgres using // a solution ID. func (s *Storage) FetchRequestBySolutionID(solutionID string) (*api.Request, error) { - sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.progress, req.created_time, req.last_updated_time "+ + sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.task, req.progress, req.created_time, req.last_updated_time "+ "FROM %s as req INNER JOIN %s as sol ON req.request_id = sol.request_id "+ "WHERE sol.solution_id = $1;", postgres.RequestTableName, postgres.SolutionTableName) @@ -164,7 +164,7 @@ func (s *Storage) FetchRequestBySolutionID(solutionID string) (*api.Request, err // FetchRequestByFittedSolutionID pulls request information from Postgres using // a fitted solution ID. func (s *Storage) FetchRequestByFittedSolutionID(fittedSolutionID string) (*api.Request, error) { - sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.progress, req.created_time, req.last_updated_time "+ + sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.task, req.progress, req.created_time, req.last_updated_time "+ "FROM %s as req INNER JOIN %s as sol ON req.request_id = sol.request_id INNER JOIN %s sr on sr.solution_id = sol.solution_id "+ "WHERE sr.fitted_solution_id = $1;", postgres.RequestTableName, postgres.SolutionTableName, postgres.SolutionResultTableName) @@ -189,11 +189,12 @@ func (s *Storage) FetchRequestByFittedSolutionID(fittedSolutionID string) (*api. 
func (s *Storage) loadRequest(rows pgx.Rows) (*api.Request, error) { var requestID string var dataset string + var task string var progress string var createdTime time.Time var lastUpdatedTime time.Time - err := rows.Scan(&requestID, &dataset, &progress, &createdTime, &lastUpdatedTime) + err := rows.Scan(&requestID, &dataset, &task, &progress, &createdTime, &lastUpdatedTime) if err != nil { return nil, errors.Wrap(err, "Unable to parse request from Postgres") } @@ -211,6 +212,7 @@ func (s *Storage) loadRequest(rows pgx.Rows) (*api.Request, error) { return &api.Request{ RequestID: requestID, Dataset: dataset, + Task: strings.Split(task, ","), Progress: progress, CreatedTime: createdTime, LastUpdatedTime: lastUpdatedTime, @@ -350,7 +352,7 @@ func (s *Storage) FetchRequestFilters(requestID string, features []*api.Feature) // FetchRequestByDatasetTarget pulls requests associated with a given dataset and target from postgres. func (s *Storage) FetchRequestByDatasetTarget(dataset string, target string) ([]*api.Request, error) { // get the solution ids - sql := fmt.Sprintf("SELECT DISTINCT ON(request.request_id) request.request_id, request.dataset, request.progress, request.created_time, request.last_updated_time "+ + sql := fmt.Sprintf("SELECT DISTINCT ON(request.request_id) request.request_id, request.dataset, request.task, request.progress, request.created_time, request.last_updated_time "+ "FROM %s request INNER JOIN %s rf ON request.request_id = rf.request_id "+ "INNER JOIN %s solution ON request.request_id = solution.request_id", postgres.RequestTableName, postgres.RequestFeatureTableName, postgres.SolutionTableName) diff --git a/api/model/storage/postgres/storage.go b/api/model/storage/postgres/storage.go index 26bf9e2a0..0c754fe3d 100644 --- a/api/model/storage/postgres/storage.go +++ b/api/model/storage/postgres/storage.go @@ -22,6 +22,7 @@ import ( "github.com/uncharted-distil/distil-compute/model" log "github.com/unchartedsoftware/plog" + "github.com/uncharted-distil/distil/api/env" api "github.com/uncharted-distil/distil/api/model" "github.com/uncharted-distil/distil/api/postgres" ) @@ -46,10 +47,48 @@ func NewDataStorage(clientCtor postgres.ClientCtor, batchClientCtor postgres.Cli } // NewSolutionStorage returns a constructor for a solution storage. -func NewSolutionStorage(clientCtor postgres.ClientCtor, metadataCtor api.MetadataStorageCtor) api.SolutionStorageCtor { - return func() (api.SolutionStorage, error) { - return newStorage(clientCtor, nil, metadataCtor) +func NewSolutionStorage(clientCtor postgres.ClientCtor, metadataCtor api.MetadataStorageCtor, updateStorage bool) (api.SolutionStorageCtor, error) { + if updateStorage { + config, err := env.LoadConfig() + if err != nil { + return nil, err + } + + // Connect to the database. 
+ postgresConfig := &postgres.Config{ + Password: config.PostgresPassword, + User: config.PostgresUser, + Database: config.PostgresDatabase, + Host: config.PostgresHost, + Port: config.PostgresPort, + PostgresLogLevel: "error", + } + pg, err := postgres.NewDatabase(postgresConfig, false) + if err != nil { + return nil, errors.Wrapf(err, "unable to initialize a new database") + } + + latestSchema, err := pg.IsLatestSchema() + if err != nil { + return nil, err + } + + if !latestSchema { + err = pg.UpdateSchema() + if err != nil { + return nil, err + } + } } + return func() (api.SolutionStorage, error) { + storage, err := newStorage(clientCtor, nil, metadataCtor) + + if err != nil { + return nil, err + } + + return storage, nil + }, nil } func newStorage(clientCtor postgres.ClientCtor, batchClientCtor postgres.ClientCtor, metadataCtor api.MetadataStorageCtor) (*Storage, error) { diff --git a/api/postgres/postgres.go b/api/postgres/postgres.go index 6aaaaa385..1ca75f397 100644 --- a/api/postgres/postgres.go +++ b/api/postgres/postgres.go @@ -80,9 +80,18 @@ const ( // WordStemTableName is the name of the table for the word stems. WordStemTableName = "word_stem" + configTableName = "config" + version = "0.1" + + configTableCreationSQL = `CREATE TABLE %s ( + key text, + value text + );` + requestTableCreationSQL = `CREATE TABLE %s ( request_id text, dataset varchar(200), + task text, progress varchar(40), created_time timestamp, last_updated_time timestamp @@ -165,6 +174,8 @@ const ( resultTableSuffix = "_result" variableTableSuffix = "_variable" explainTableSuffix = "_explain" + + distilSchemaKey = "distil-schema-version" ) var ( @@ -229,6 +240,98 @@ func NewDatabase(config *Config, batch bool) (*Database, error) { return database, nil } +// IsLatestSchema returns true if the solution metadata schema matches the latest. +func (d *Database) IsLatestSchema() (bool, error) { + // check for the presence of the config table + log.Infof("verifying that postgres is using the latest schema") + configExists, err := d.tableExists(configTableName) + if err != nil { + return false, err + } + + // if the config table isnt there, then it isnt the latest + if !configExists { + log.Infof("postgres not using latest schema as the config table does not exist") + return false, nil + } + + // check the version stored in the config table against the latest version + config, err := d.loadConfig() + if err != nil { + return false, err + } + + log.Infof("postgres schema version %s and the latest version is %s", config[distilSchemaKey], version) + return config[distilSchemaKey] == version, nil +} + +// UpdateSchema updates the metadata schema and stores the version to the database. 
+func (d *Database) UpdateSchema() error { + // recreate metadata tables + err := d.CreateSolutionMetadataTables() + if err != nil { + return err + } + + // write the version to the config table + configToStore := map[string]string{distilSchemaKey: version} + for k, v := range configToStore { + sql := fmt.Sprintf("INSERT INTO %s (key, value) VALUES ($1, $2);", configTableName) + _, err = d.Client.Exec(sql, k, v) + if err != nil { + return errors.Wrapf(err, "unable to store postgres config") + } + } + + return nil +} + +func (d *Database) loadConfig() (map[string]string, error) { + log.Infof("reading postgres config") + sql := fmt.Sprintf("SELECT key, value FROM %s;", configTableName) + + rows, err := d.Client.Query(sql) + if err != nil { + return nil, errors.Wrapf(err, "unable to query postgres config") + } + defer rows.Close() + + configData := map[string]string{} + for rows.Next() { + var key string + var value string + + err = rows.Scan(&key, &value) + if err != nil { + return nil, errors.Wrapf(err, "unable to parse postgres config") + } + configData[key] = value + } + + log.Infof("postgres config: %v", configData) + + return configData, nil +} + +func (d *Database) tableExists(name string) (bool, error) { + sql := "SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1);" + + rows, err := d.Client.Query(sql, name) + if err != nil { + return false, errors.Wrapf(err, "unable to verify if a table exists") + } + defer rows.Close() + + rows.Next() + var exists bool + err = rows.Scan(&exists) + if err != nil { + return false, errors.Wrap(err, "unable to parse table existence result") + } + + return exists, nil +} + // CreateSolutionMetadataTables creates an empty table for the solution results. func (d *Database) CreateSolutionMetadataTables() error { // Create the solution tables.
@@ -240,6 +343,12 @@ func (d *Database) CreateSolutionMetadataTables() error { return errors.Wrap(err, "failed to drop table") } + _ = d.DropTable(configTableName) + _, err = d.Client.Exec(fmt.Sprintf(configTableCreationSQL, configTableName)) + if err != nil { + return errors.Wrap(err, "failed to drop table") + } + _ = d.DropTable(RequestTableName) _, err = d.Client.Exec(fmt.Sprintf(requestTableCreationSQL, RequestTableName)) if err != nil { diff --git a/api/routes/image_pack.go b/api/routes/image_pack.go index d545dfc32..5d4799361 100644 --- a/api/routes/image_pack.go +++ b/api/routes/image_pack.go @@ -80,7 +80,9 @@ func MultiBandImagePackHandler(ctor api.MetadataStorageCtor, dataCtor api.DataSt funcPointer := getImages optramMap := map[string]imagery.OptramEdges{} precision := 0 - if params.Band != "" { + if params.Band == imagery.Segmentation { + funcPointer = getSegmentationImages + } else if params.Band != "" { // if band is not empty then get multiBandImages funcPointer = getMultiBandImages if params.Band == imagery.OPTRAM { @@ -200,6 +202,7 @@ func getImages(imagePackRequest *ImagePackRequest, _ map[string]imagery.OptramEd } result <- chanStruct{data: temp, IDs: IDs, errorIDs: errorIDs} } + func getMultiBandImages(multiBandPackRequest *ImagePackRequest, optramMap map[string]imagery.OptramEdges, precision int, threadID int, numThreads int, result chan chanStruct, ctor api.MetadataStorageCtor, dataCtor api.DataStorageCtor) { temp := [][]byte{} IDs := []string{} @@ -278,6 +281,46 @@ func getMultiBandImages(multiBandPackRequest *ImagePackRequest, optramMap map[st result <- chanStruct{data: temp, IDs: IDs, errorIDs: errorIDs} } +func getSegmentationImages(multiBandPackRequest *ImagePackRequest, optramMap map[string]imagery.OptramEdges, precision int, threadID int, numThreads int, result chan chanStruct, ctor api.MetadataStorageCtor, dataCtor api.DataStorageCtor) { + temp := [][]byte{} + IDs := []string{} + errorIDs := []string{} + // get common storage + storage, err := ctor() + if err != nil { + log.Error(err) + return + } + + res, err := storage.FetchDataset(multiBandPackRequest.Dataset, false, false, false) + if err != nil { + log.Error(err) + return + } + + sourcePath := path.Join(env.GetResourcePath(), res.ID, "media") + + // loop through image info + for i := threadID; i < len(multiBandPackRequest.ImageIDs); i += numThreads { + imageID := multiBandPackRequest.ImageIDs[i] + img, err := getSegmentationImage(imageID, sourcePath, true, ThumbnailDimensions) + if err != nil { + handleThreadError(&errorIDs, &imageID, &err) + continue + } + + imageBytes, err := imagery.ImageToJPEG(img) + if err != nil { + handleThreadError(&errorIDs, &imageID, &err) + continue + } + temp = append(temp, imageBytes) + IDs = append(IDs, imageID) + } + + result <- chanStruct{data: temp, IDs: IDs, errorIDs: errorIDs} +} + func handleThreadError(errorIDs *[]string, imageID *string, err *error) { *errorIDs = append(*errorIDs, *imageID) log.Error(*err) diff --git a/api/routes/multiband_image.go b/api/routes/multiband_image.go index 2f2c54334..8918ede56 100644 --- a/api/routes/multiband_image.go +++ b/api/routes/multiband_image.go @@ -16,13 +16,18 @@ package routes import ( + "bytes" "encoding/json" "fmt" + "image" + "image/draw" + "io/ioutil" "net/http" "path" "strconv" "strings" + "github.com/nfnt/resize" "github.com/pkg/errors" "github.com/uncharted-distil/distil-compute/metadata" "github.com/uncharted-distil/distil-compute/model" @@ -69,66 +74,77 @@ func MultiBandImageHandler(ctor api.MetadataStorageCtor, 
dataCtor api.DataStorag handleError(w, err) return } - res, err := storage.FetchDataset(dataset, false, false, false) if err != nil { handleError(w, err) return } - sourcePath := env.ResolvePath(res.Source, res.Folder) + options := imagery.Options{Gain: 2.5, Gamma: 2.2, GainL: 1.0, Scale: false} // default options for color correction - // need to read the dataset doc to determine the path to the data resource - metaDisk, err := metadata.LoadMetadataFromOriginalSchema(path.Join(sourcePath, compute.D3MDataSchema), false) - if err != nil { - handleError(w, err) - return - } - for _, dr := range metaDisk.DataResources { - if dr.IsCollection && dr.ResType == model.ResTypeImage { - sourcePath = model.GetResourcePathFromFolder(sourcePath, dr) - break + var img *image.RGBA + if bandCombo == imagery.Segmentation { + sourcePath := path.Join(env.GetResourcePath(), res.ID, "media") + img, err = getSegmentationImage(imageID, sourcePath, false, 0) + if err != nil { + handleError(w, err) + return } - } - options := imagery.Options{Gain: 2.5, Gamma: 2.2, GainL: 1.0, Scale: false} // default options for color correction - if paramOption != "" { - err := json.Unmarshal([]byte(paramOption), &options) + } else { + sourcePath := env.ResolvePath(res.Source, res.Folder) + + // need to read the dataset doc to determine the path to the data resource + metaDisk, err := metadata.LoadMetadataFromOriginalSchema(path.Join(sourcePath, compute.D3MDataSchema), false) if err != nil { handleError(w, err) return } - } - if isThumbnail { - imageScale = imagery.ImageScale{Width: ThumbnailDimensions, Height: ThumbnailDimensions} - // if thumbnail scale should be 0 - options.Scale = false - } + for _, dr := range metaDisk.DataResources { + if dr.IsCollection && dr.ResType == model.ResTypeImage { + sourcePath = model.GetResourcePathFromFolder(sourcePath, dr) + break + } + } + if paramOption != "" { + err := json.Unmarshal([]byte(paramOption), &options) + if err != nil { + handleError(w, err) + return + } + } + if isThumbnail { + imageScale = imagery.ImageScale{Width: ThumbnailDimensions, Height: ThumbnailDimensions} + // if thumbnail scale should be 0 + options.Scale = false + } - // need to get the band -> filename from the data - bandMapping, err := getBandMapping(res, []string{imageID}, dataStorage) - if err != nil { - handleError(w, err) - return - } - var optramMap map[string]imagery.OptramEdges - optramPath := "" - edge := imagery.OptramEdges{} - precision := 0 - if bandCombo == imagery.OPTRAM { - optramPath = strings.Join([]string{env.ResolvePath(res.Source, res.Folder), imagery.OPTRAMJSONFile}, "/") - optramMap, precision, err = imagery.ReadOptramFile(optramPath) + // need to get the band -> filename from the data + bandMapping, err := getBandMapping(res, []string{imageID}, dataStorage) if err != nil { handleError(w, err) return } - geoHash := imagery.ParseGeoHashFromID(imageID, precision) - edge = optramMap[geoHash] - } + var optramMap map[string]imagery.OptramEdges + optramPath := "" + edge := imagery.OptramEdges{} + precision := 0 + if bandCombo == imagery.OPTRAM { + optramPath = strings.Join([]string{env.ResolvePath(res.Source, res.Folder), imagery.OPTRAMJSONFile}, "/") + optramMap, precision, err = imagery.ReadOptramFile(optramPath) + if err != nil { + handleError(w, err) + return + } + geoHash := imagery.ParseGeoHashFromID(imageID, precision) + edge = optramMap[geoHash] + } - img, err := imagery.ImageFromCombination(sourcePath, bandMapping[imageID], bandCombo, imageScale, &edge, ramp, options) - if err != nil { - 
handleError(w, err) - return + img, err = imagery.ImageFromCombination(sourcePath, bandMapping[imageID], bandCombo, imageScale, &edge, ramp, options) + if err != nil { + handleError(w, err) + return + } } + if options.Scale && config.ShouldScaleImages { img = c_util.UpscaleImage(img, c_util.GetModelType(config.ModelType)) } @@ -145,6 +161,32 @@ func MultiBandImageHandler(ctor api.MetadataStorageCtor, dataCtor api.DataStorag } } +func getSegmentationImage(imageID string, sourcePath string, thumbnail bool, dimensions int) (*image.RGBA, error) { + data, err := ioutil.ReadFile(path.Join(sourcePath, fmt.Sprintf("%s-segmentation.png", imageID))) + if err != nil { + return nil, errors.Wrapf(err, "unable to read segmentation image") + } + + img, _, err := image.Decode(bytes.NewReader(data)) + if err != nil { + return nil, errors.Wrapf(err, "unable to decode segmentation image") + } + dimensionsY := dimensions + dimensionsX := dimensions + if thumbnail { + img = resize.Thumbnail(uint(dimensionsX), uint(dimensionsY), img, resize.Lanczos3) + } else { + size := img.Bounds().Size() + dimensionsY = size.X + dimensionsX = size.Y + } + + rgbaImg := image.NewRGBA(image.Rect(0, 0, dimensionsX, dimensionsY)) + draw.Draw(rgbaImg, image.Rect(0, 0, dimensionsX, dimensionsY), img, img.Bounds().Min, draw.Src) + + return rgbaImg, nil +} + func getBandMapping(ds *api.Dataset, groupKeys []string, dataStorage api.DataStorage) (map[string]map[string]string, error) { // build a filter to only include rows matching a group id var groupingCol string diff --git a/api/routes/segmentation.go b/api/routes/segmentation.go new file mode 100644 index 000000000..93e2d11cf --- /dev/null +++ b/api/routes/segmentation.go @@ -0,0 +1,55 @@ +package routes + +import ( + "net/http" + + "github.com/pkg/errors" + "goji.io/v3/pat" + + "github.com/uncharted-distil/distil/api/env" + api "github.com/uncharted-distil/distil/api/model" + "github.com/uncharted-distil/distil/api/task" +) + +// SegmentationHandler will segment the specified remote sensing dataset. 
+func SegmentationHandler(metaCtor api.MetadataStorageCtor, dataCtor api.DataStorageCtor, config env.Config) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + // get dataset name + dataset := pat.Param(r, "dataset") + // get variable name + variable := pat.Param(r, "variable") + + metaStorage, err := metaCtor() + if err != nil { + handleError(w, err) + return + } + + dataStorage, err := dataCtor() + if err != nil { + handleError(w, err) + return + } + + ds, err := metaStorage.FetchDataset(dataset, false, false, false) + if err != nil { + handleError(w, err) + return + } + + outputURI, err := task.Segment(ds, dataStorage, variable) + if err != nil { + handleError(w, errors.Wrap(err, "unable to segment dataset")) + return + } + + // marshal output into JSON + err = handleJSON(w, map[string]interface{}{ + "uri": outputURI, + }) + if err != nil { + handleError(w, errors.Wrap(err, "unable to marshal segmentation result into JSON")) + return + } + } +} diff --git a/api/task/cleaning.go b/api/task/cleaning.go index ec9e86542..7b2ffe5ba 100644 --- a/api/task/cleaning.go +++ b/api/task/cleaning.go @@ -59,7 +59,7 @@ func Clean(schemaFile string, dataset string, params *IngestParams, config *Inge } // create & submit the solution request - pip, err := description.CreateDataCleaningPipeline("Mary Poppins", "", vars) + pip, err := description.CreateDataCleaningPipeline("Mary Poppins", "", vars, true) if err != nil { return "", errors.Wrap(err, "unable to create format pipeline") } diff --git a/api/task/dataset.go b/api/task/dataset.go index bb6793bd1..97fd2c4a7 100644 --- a/api/task/dataset.go +++ b/api/task/dataset.go @@ -129,7 +129,40 @@ func CopyDiskDataset(existingURI string, newURI string, newDatasetID string, new // it to disk in D3M dataset format.
func ExportDataset(dataset string, metaStorage api.MetadataStorage, dataStorage api.DataStorage, filterParams *api.FilterParams) (string, string, error) { // TODO: most likely need to either get a unique folder name for output or error if already exists - return exportDiskDataset(dataset, dataset, env.ResolvePath(metadata.Augmented, dataset), metaStorage, dataStorage, false, filterParams) + datasetID, datasetPath, err := exportDiskDataset(dataset, dataset, env.ResolvePath(metadata.Augmented, dataset), metaStorage, dataStorage, false, filterParams) + if err != nil { + return "", "", err + } + + // update the metadata stored to have the index reflect what is on disk + metaDisk, err := serialization.ReadMetadata(path.Join(datasetPath, compute.D3MDataSchema)) + if err != nil { + return "", "", err + } + + metaStored, err := metaStorage.FetchDataset(datasetID, true, true, true) + if err != nil { + return "", "", err + } + + diskVars := api.MapVariables(metaDisk.GetMainDataResource().Variables, func(variable *model.Variable) string { return variable.Key }) + notIncludedCount := 0 + for _, v := range metaStored.Variables { + vDisk := diskVars[v.Key] + if vDisk == nil { + // variable not in disk dataset so arbitrarily give it an index > # of variables on disk + v.Index = len(diskVars) + notIncludedCount + notIncludedCount++ + } else { + v.Index = vDisk.Index + } + } + err = metaStorage.UpdateDataset(metaStored) + if err != nil { + return "", "", err + } + + return datasetID, datasetPath, nil } // CreateDatasetFromResult creates a new dataset based on a result set & the input diff --git a/api/task/pipelines.go b/api/task/pipelines.go index 51fe3f092..0e56498d8 100644 --- a/api/task/pipelines.go +++ b/api/task/pipelines.go @@ -62,7 +62,11 @@ func SetClient(computeClient *compute.Client) { } func submitPipeline(datasets []string, step *description.FullySpecifiedPipeline, shouldCache bool) (string, error) { - return sr.SubmitPipeline(client, datasets, nil, nil, step, nil, shouldCache) + result, err := sr.SubmitPipeline(client, datasets, nil, nil, step, nil, shouldCache) + if err != nil { + return "", err + } + return result.ResultURI, nil } func getD3MIndexField(dr *model.DataResource) int { diff --git a/api/task/prediction.go b/api/task/prediction.go index 4813dd73b..e7fa48420 100644 --- a/api/task/prediction.go +++ b/api/task/prediction.go @@ -206,11 +206,12 @@ type PredictParams struct { DatasetConstructor DatasetConstructor OutputPath string IndexFields []string + Task *comp.Task Target *model.Variable MetaStorage api.MetadataStorage DataStorage api.DataStorage SolutionStorage api.SolutionStorage - ModelStorage api.ExportedModelStorage + ExportedModel *api.ExportedModel IngestConfig *IngestTaskConfig Config *env.Config } @@ -695,12 +696,8 @@ func Predict(params *PredictParams) (string, error) { // Ensure the ta2 has fitted solution loaded. If the model wasn't saved, it should be available // as part of the session. 
- exportedModel, err := params.ModelStorage.FetchModelByID(params.FittedSolutionID) - if err != nil { - return "", err - } - if exportedModel != nil { - _, err = LoadFittedSolution(exportedModel.FilePath, params.SolutionStorage, params.MetaStorage) + if params.ExportedModel != nil { + _, err = LoadFittedSolution(params.ExportedModel.FilePath, params.SolutionStorage, params.MetaStorage) if err != nil { return "", err } @@ -708,7 +705,8 @@ func Predict(params *PredictParams) (string, error) { // submit the new dataset for predictions log.Infof("generating predictions using data found at '%s'", params.SchemaPath) - predictionResult, err := comp.GeneratePredictions(params.SchemaPath, solution.SolutionID, params.FittedSolutionID, client) + predictionResult, err := comp.GeneratePredictions(params.Dataset, params.SchemaPath, + solution.SolutionID, params.FittedSolutionID, params.Task, params.Target.HeaderName, params.MetaStorage, params.DataStorage, client) if err != nil { return "", err } diff --git a/api/task/segment.go b/api/task/segment.go new file mode 100644 index 000000000..6dbd8a582 --- /dev/null +++ b/api/task/segment.go @@ -0,0 +1,122 @@ +// +// Copyright © 2021 Uncharted Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "fmt" + "os" + "path" + "strconv" + + "github.com/pkg/errors" + + "github.com/uncharted-distil/distil-compute/model" + "github.com/uncharted-distil/distil-compute/primitive/compute/description" + "github.com/uncharted-distil/distil-compute/primitive/compute/result" + "github.com/uncharted-distil/distil/api/env" + api "github.com/uncharted-distil/distil/api/model" + "github.com/uncharted-distil/distil/api/util" + "github.com/uncharted-distil/distil/api/util/imagery" +) + +// Segment segments an image into separate parts. 
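This hunk moves responsibility for looking up the exported model to the caller: Predict now receives an optional exported model and only reloads the fitted solution when one is present. A simplified sketch of the pattern, with stand-in types rather than the real PredictParams and LoadFittedSolution:

```go
// Sketch of the dependency change in Predict: the exported model is resolved
// by the caller and passed in, so a nil value means the fitted solution is
// assumed to still be available in the session.
package main

import "fmt"

type ExportedModel struct {
	FilePath string
}

type PredictParams struct {
	FittedSolutionID string
	ExportedModel    *ExportedModel // nil when the model was never saved
}

// placeholder standing in for task.LoadFittedSolution
func loadFittedSolution(filePath string) error {
	fmt.Println("loading fitted solution from", filePath)
	return nil
}

func predict(params *PredictParams) error {
	// only reload from disk when an exported model exists
	if params.ExportedModel != nil {
		if err := loadFittedSolution(params.ExportedModel.FilePath); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	_ = predict(&PredictParams{FittedSolutionID: "fitted-123", ExportedModel: &ExportedModel{FilePath: "/models/fitted-123"}})
	_ = predict(&PredictParams{FittedSolutionID: "fitted-456"}) // unsaved model: nothing to load
}
```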
+func Segment(ds *api.Dataset, dataStorage api.DataStorage, variableName string) (string, error) { + envConfig, err := env.LoadConfig() + if err != nil { + return "", err + } + + datasetInputDir := env.ResolvePath(ds.Source, ds.Folder) + + var variable *model.Variable + for _, v := range ds.Variables { + if v.Key == variableName { + variable = v + break + } + } + + step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", variable, + envConfig.RemoteSensingNumJobs, envConfig.RemoteSensingGPUBatchSize) + if err != nil { + return "", err + } + + resultURI, err := submitPipeline([]string{datasetInputDir}, step, true) + if err != nil { + return "", err + } + + // read the file and parse the output mask + result, err := result.ParseResultCSV(resultURI) + if err != nil { + return "", err + } + + // need to pull the data to properly map d3m index to expected file names + // filenames should be "groupid-segmentation.png" for now + // TODO: may need to build the grouping key from multiple fields when moving away from test + var groupingKey *model.Variable + for _, v := range ds.Variables { + if v.HasRole(model.VarDistilRoleGrouping) { + groupingKey = v + break + } + } + if groupingKey == nil { + return "", errors.Errorf("no grouping found to use for output filename") + } + mapping, err := api.BuildFieldMapping(ds.ID, ds.StorageName, model.D3MIndexFieldName, groupingKey.Key, dataStorage) + if err != nil { + return "", err + } + + // need to output all the masks as images + imageOutputFolder := path.Join(env.GetResourcePath(), ds.ID, "media") + for _, r := range result[1:] { + // create the image that captures the mask + d3mIndex := r[0].(string) + rawMask := r[1].([]interface{}) + rawFloats := make([][]float64, len(rawMask)) + for i, f := range rawMask { + dataF := f.([]interface{}) + nestedFloats := make([]float64, len(dataF)) + for j, nf := range dataF { + fp, err := strconv.ParseFloat(nf.(string), 64) + if err != nil { + return "", errors.Wrapf(err, "unable to parse mask") + } + nestedFloats[j] = fp + } + rawFloats[i] = nestedFloats + } + + filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(255)) + imageBytes, err := imagery.ImageToPNG(filter) + if err != nil { + return "", err + } + + // write the image to disk using a basic naming convention + imageFilename := path.Join(imageOutputFolder, fmt.Sprintf("%s-segmentation.png", mapping[d3mIndex])) + err = util.WriteFileWithDirs(imageFilename, imageBytes, os.ModePerm) + if err != nil { + return "", err + } + } + + return "", nil +} diff --git a/api/task/solution.go b/api/task/solution.go index 2b39e891c..dcd4c3cca 100644 --- a/api/task/solution.go +++ b/api/task/solution.go @@ -88,6 +88,7 @@ func SaveFittedSolution(fittedSolutionID string, modelName string, modelDescript FittedSolutionID: fittedSolutionID, DatasetID: request.Dataset, DatasetName: dataset.Name, + Task: request.Task, Variables: vars, VariableDetails: varDetails, Target: target, diff --git a/api/util/imagery/imagery.go b/api/util/imagery/imagery.go index e1905e1ea..314103a90 100644 --- a/api/util/imagery/imagery.go +++ b/api/util/imagery/imagery.go @@ -71,6 +71,9 @@ const ( // AtmosphericRemoval identifies a band mapping that displays an image in near true color with atmoshperic effects reduced. AtmosphericRemoval = "atmospheric_removal" + // Segmentation identifies a placeholder band mapping to display image segmentation output. 
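Segment reads the pipeline result CSV, converts each stringified confidence matrix to floats, and renders one mask image per group. The sketch below shows just the parse-and-render step; it uses a plain grayscale PNG as a stand-in for imagery.ConfidenceMatrixToImage with the Magma colour scale, and assumes confidences in [0, 1]:

```go
// Minimal sketch of the mask-to-image step in Segment. The grayscale
// rendering is a stand-in for the project's imagery helpers.
package main

import (
	"image"
	"image/color"
	"image/png"
	"os"
	"strconv"
)

// parseMask converts a matrix of stringified confidences into float64s.
func parseMask(raw [][]string) ([][]float64, error) {
	parsed := make([][]float64, len(raw))
	for i, row := range raw {
		parsed[i] = make([]float64, len(row))
		for j, cell := range row {
			v, err := strconv.ParseFloat(cell, 64)
			if err != nil {
				return nil, err
			}
			parsed[i][j] = v
		}
	}
	return parsed, nil
}

// maskToPNG writes the mask as a grayscale PNG, assuming values in [0, 1].
func maskToPNG(mask [][]float64, filename string) error {
	h := len(mask)
	w := 0
	if h > 0 {
		w = len(mask[0])
	}
	img := image.NewGray(image.Rect(0, 0, w, h))
	for y, row := range mask {
		for x, v := range row {
			img.SetGray(x, y, color.Gray{Y: uint8(v * 255)})
		}
	}
	f, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer f.Close()
	return png.Encode(f, img)
}

func main() {
	mask, err := parseMask([][]string{{"0.1", "0.9"}, {"0.8", "0.2"}})
	if err != nil {
		panic(err)
	}
	if err := maskToPNG(mask, "example-segmentation.png"); err != nil {
		panic(err)
	}
}
```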
+ Segmentation = "segmentation" + // ShortwaveInfrared identifies a band mapping that displays an image in shortwave infrared. ShortwaveInfrared = "shortwave_infrared" @@ -183,6 +186,7 @@ var ( ) func init() { + config, _ := env.LoadConfig() // initialize the band combination structures - needs to be done in init so that referenced color ramps are built SentinelBandCombinations = map[string]*BandCombination{ NaturalColors1: {NaturalColors1, "Natural Colors", []string{"b04", "b03", "b02"}, nil, nil, false}, @@ -204,6 +208,10 @@ func init() { RSWIR: {RSWIR, "Red and Shortwave Infrared", []string{"b04", "b11"}, BrownYellowBlueRamp, NormalizingTransform, false}, OPTRAM: {OPTRAM, "OPTRAM", []string{"b08", "b04", "b12"}, RedYellowGreenRamp, OptramTransform, false}, } + + if config.SegmentationEnabled { + SentinelBandCombinations[Segmentation] = &BandCombination{Segmentation, "Segmentation", []string{}, nil, nil, false} + } } // Initialize sets up the necessary structures for imagery processing. diff --git a/api/ws/pipeline.go b/api/ws/pipeline.go index 333b833f6..8992d8958 100644 --- a/api/ws/pipeline.go +++ b/api/ws/pipeline.go @@ -158,12 +158,13 @@ func handleCreateSolutions(conn *Connection, client *compute.Client, metadataCto // load defaults config, _ := env.LoadConfig() - if len(request.Task) == 0 { - request.Task = api.DefaultTaskType(request.TargetFeature.Type, request.ProblemType) - log.Infof("Defaulting task type to `%s`", request.Task) + metricTasks := request.Task + if len(metricTasks) == 0 { + metricTasks = api.DefaultTaskType(request.TargetFeature.Type, request.ProblemType) + log.Infof("Defaulting metric task type to `%s`", metricTasks) } if len(request.Metrics) == 0 { - request.Metrics = api.DefaultMetrics(request.Task) + request.Metrics = api.DefaultMetrics(metricTasks) log.Infof("Defaulting metrics to `%s`", strings.Join(request.Metrics, ",")) } if request.MaxTime == 0 { @@ -412,13 +413,27 @@ func handlePredict(conn *Connection, client *compute.Client, metadataCtor apiMod return } - // resolve the task so we know what type of data we should be expecting - requestTask, err := api.ResolveTask(dataStorage, meta.StorageName, targetVar, variables) + // get the exported model + exportedModel, err := modelStorage.FetchModelByID(request.FittedSolutionID) if err != nil { handleErr(conn, msg, err) return } + // if the task wasnt saved, determine it via the target and features + var modelTask *api.Task + if len(exportedModel.Task) > 0 { + modelTask = &api.Task{ + Task: exportedModel.Task, + } + } else { + modelTask, err = api.ResolveTask(dataStorage, meta.StorageName, targetVar, variables) + if err != nil { + handleErr(conn, msg, err) + return + } + } + // config objects required for ingest config, _ := env.LoadConfig() ingestConfig := task.NewConfig(config) @@ -430,21 +445,32 @@ func handlePredict(conn *Connection, client *compute.Client, metadataCtor apiMod SolutionID: sr.SolutionID, FittedSolutionID: request.FittedSolutionID, OutputPath: path.Join(config.D3MOutputDir, config.AugmentedSubFolder), + Task: modelTask, Target: targetVar, MetaStorage: metaStorage, DataStorage: dataStorage, SolutionStorage: solutionStorage, - ModelStorage: modelStorage, + ExportedModel: exportedModel, Config: &config, IngestConfig: ingestConfig, SourceDatasetID: meta.ID, } - datasetName, datasetPath, err := getPredictionDataset(requestTask, request, predictParams) + datasetName, datasetPath, err := getPredictionDataset(modelTask, request, predictParams) if err != nil { handleErr(conn, msg, errors.Wrap(err, 
"unable to create raw dataset")) return } + + // if the task is a segmentation task, run it against the base dataset + if api.HasTaskType(modelTask, compute.SegmentationTask) { + dsPred, err := metaStorage.FetchDataset(datasetName, true, true, true) + if err != nil { + handleErr(conn, msg, errors.Wrap(err, "unable to resolve prediction dataset")) + return + } + datasetPath = path.Join(env.ResolvePath(dsPred.Source, dsPred.Folder), compute.D3MDataSchema) + } predictParams.Dataset = datasetName predictParams.SchemaPath = datasetPath diff --git a/go.mod b/go.mod index 3bb0ce8a9..5c67f4875 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/russross/blackfriday v2.0.0+incompatible github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/stretchr/testify v1.6.1 - github.com/uncharted-distil/distil-compute v0.0.0-20211111182155-5ec97a35a8cc + github.com/uncharted-distil/distil-compute v0.0.0-20220818194426-a130f919e111 github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5 diff --git a/go.sum b/go.sum index 36eb1cfa9..86f82af75 100644 --- a/go.sum +++ b/go.sum @@ -213,8 +213,10 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/uncharted-distil/distil-compute v0.0.0-20211111182155-5ec97a35a8cc h1:I0fgw+Tb2rqFzj37Ux5SM0KMf48HZ0RxTbH+QGXoy5w= -github.com/uncharted-distil/distil-compute v0.0.0-20211111182155-5ec97a35a8cc/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8= +github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93 h1:UNSU3FX3h4k8wrzzXWLtX2kl4bb2AW7BqoV2FkQigRs= +github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8= +github.com/uncharted-distil/distil-compute v0.0.0-20220818194426-a130f919e111 h1:HRYDNq9tSNcqZ02mzfnp8Ee+piBUSGt6vXUjjZxKxIM= +github.com/uncharted-distil/distil-compute v0.0.0-20220818194426-a130f919e111/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8= github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY= github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc= github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y= diff --git a/main.go b/main.go index dd4227229..127e45790 100644 --- a/main.go +++ b/main.go @@ -166,7 +166,11 @@ func main() { pgDataStorageCtor := pg.NewDataStorage(postgresClientCtor, postgresBatchClientCtor, esMetadataStorageCtor) // instantiate the postgres solution storage constructor. 
- pgSolutionStorageCtor := pg.NewSolutionStorage(postgresClientCtor, esMetadataStorageCtor) + pgSolutionStorageCtor, err := pg.NewSolutionStorage(postgresClientCtor, esMetadataStorageCtor, config.PostgresUpdate) + if err != nil { + log.Errorf("%+v", err) + os.Exit(1) + } // Instantiate the solution compute client solutionClient, err := task.NewDefaultClient(config, userAgent, discoveryLogger) @@ -297,6 +301,7 @@ func main() { registerRoutePost(mux, "/distil/update/:dataset", routes.UpdateHandler(esMetadataStorageCtor, pgDataStorageCtor, config)) registerRoutePost(mux, "/distil/clone-result/:produce-request-id", routes.CloningResultsHandler(esMetadataStorageCtor, pgDataStorageCtor, pgSolutionStorageCtor, config)) registerRoutePost(mux, "/distil/clone/:dataset", routes.CloningHandler(esMetadataStorageCtor, pgDataStorageCtor, config)) + registerRoutePost(mux, "/distil/segment/:dataset/:variable", routes.SegmentationHandler(esMetadataStorageCtor, pgDataStorageCtor, config)) registerRoutePost(mux, "/distil/save-dataset/:dataset", routes.SaveDatasetHandler(esMetadataStorageCtor, pgDataStorageCtor, config)) registerRoutePost(mux, "/distil/add-field/:dataset", routes.AddFieldHandler(esMetadataStorageCtor, pgDataStorageCtor)) registerRoutePost(mux, "/distil/extract/:dataset", routes.ExtractHandler(esMetadataStorageCtor, pgDataStorageCtor, config)) diff --git a/public/components/CreateSolutionsForm.vue b/public/components/CreateSolutionsForm.vue index d1d95c740..02e722a0b 100644 --- a/public/components/CreateSolutionsForm.vue +++ b/public/components/CreateSolutionsForm.vue @@ -173,6 +173,8 @@ export default Vue.extend({ // flag as pending this.pending = true; // dispatch action that triggers request send to server + const selectedTask = routeGetters.getRouteSelectedTask(this.$store); + const taskToRun = selectedTask ? selectedTask.split(",") : null; const routeSplit = routeGetters.getRouteTrainTestSplit(this.$store); const defaultSplit = appGetters.getTrainTestSplit(this.$store); const timestampSplit = routeGetters.getRouteTimestampSplit(this.$store); @@ -193,6 +195,7 @@ export default Vue.extend({ quality: routeGetters.getModelQuality(this.$store), trainTestSplit: !!routeSplit ? routeSplit : defaultSplit, timestampSplitValue: timestampSplit, + task: taskToRun, } as SolutionRequestMsg; // Add optional values to the request @@ -208,13 +211,16 @@ export default Vue.extend({ this.pending = false; const dataMode = routeGetters.getDataMode(this.$store); const dataModeDefault = dataMode ? dataMode : DataMode.Default; + const taskUsed = selectedTask + ? selectedTask + : routeGetters.getRouteTask(this.$store); // transition to result screen const entry = createRouteEntry(RESULTS_ROUTE, { dataset: routeGetters.getRouteDataset(this.$store), target: routeGetters.getRouteTargetVariable(this.$store), solutionId: res.solutionId, - task: routeGetters.getRouteTask(this.$store), + task: taskUsed, dataMode: dataModeDefault, varModes: varModesToString( routeGetters.getDecodedVarModes(this.$store) diff --git a/public/components/SettingsModal.vue b/public/components/SettingsModal.vue index 9b502f76a..2f1b225cc 100644 --- a/public/components/SettingsModal.vue +++ b/public/components/SettingsModal.vue @@ -131,6 +131,18 @@
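The new route is wired up through registerRoutePost in main.go. For reference, a sketch of the equivalent goji wiring, assuming the goji.io router implied by the pat.Param calls in the handler; the handler body below is illustrative only and not the project's registration helper:

```go
// Sketch of registering the segmentation route on a goji mux, mirroring
// registerRoutePost(mux, "/distil/segment/:dataset/:variable", ...).
package main

import (
	"fmt"
	"log"
	"net/http"

	"goji.io"
	"goji.io/pat"
)

func segmentationHandler(w http.ResponseWriter, r *http.Request) {
	dataset := pat.Param(r, "dataset")
	variable := pat.Param(r, "variable")
	// placeholder response; the real handler runs task.Segment
	fmt.Fprintf(w, `{"uri": "segmented %s/%s"}`, dataset, variable)
}

func main() {
	mux := goji.NewMux()
	mux.HandleFunc(pat.Post("/distil/segment/:dataset/:variable"), segmentationHandler)
	log.Fatal(http.ListenAndServe(":8080", mux))
}
```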

+ + + + {{ task }} + + + @@ -171,6 +183,7 @@ export default Vue.extend({ // fill this from the API later, first posting back the target's type // then getting a list of allowed scoring methods with keys, description selectedMetric: null, + selectedTask: null, trainingCount: 1, timestampSplitValue: new Date(), splitByTime: false, @@ -196,10 +209,43 @@ export default Vue.extend({ }); }, + multipleTasks(): boolean { + // hack to only really be true when classification and segmentation is possible + return ( + this.task.includes(TaskTypes.REMOTE_SENSING) && this.tasks.length > 1 + ); + }, + + tasks(): string[] { + // hack to only really be allow for classification and segmentation + return this.task + .split(",") + .filter((t) => t != TaskTypes.REMOTE_SENSING && t != TaskTypes.BINARY); + }, + task(): string { return routeGetters.getRouteTask(this.$store) ?? ""; }, + rebuildTask(): string { + // hack to submit only either classification or segmentation when dealing with remote sensing + if (this.multipleTasks) { + // if no task selected, then return null + if (this.selectedTask) { + return ( + TaskTypes.REMOTE_SENSING + + "," + + TaskTypes.BINARY + + "," + + this.selectedTask + ); + } + return null; + } + + return this.task; + }, + totalDataCount(): number { return datasetGetters.getIncludedTableDataNumRows(this.$store); }, @@ -310,6 +356,7 @@ export default Vue.extend({ modelTimeLimit: this.timeLimit, modelQuality: this.speedQuality, metrics: this.selectedMetric, + selectedTask: this.rebuildTask, trainTestSplit: this.trainingRatio, timestampSplit: this.hasTimeRange && this.splitByTime diff --git a/public/store/dataset/index.ts b/public/store/dataset/index.ts index c34c06289..61b2352a9 100644 --- a/public/store/dataset/index.ts +++ b/public/store/dataset/index.ts @@ -267,6 +267,7 @@ export interface TimeseriesExtrema { // task string definitions - should mirror those defined in the MIT/LL d3m problem schema export enum TaskTypes { CLASSIFICATION = "classification", + SEGMENTATION = "segmentation", REGRESSION = "regression", CLUSTERING = "clustering", LINK_PREDICTION = "linkPrediction", diff --git a/public/store/requests/actions.ts b/public/store/requests/actions.ts index f0355f89e..ce78abb5f 100644 --- a/public/store/requests/actions.ts +++ b/public/store/requests/actions.ts @@ -74,6 +74,7 @@ export interface SolutionRequestMsg { target: string; timestampSplitValue?: number; trainTestSplit: number; + task?: string[]; } // Solution status message used in web socket context @@ -596,6 +597,7 @@ export const actions = { filters: request.filters, trainTestSplit: request.trainTestSplit, timestampSplitValue: request.timestampSplitValue, + task: request.task, }); }); }, diff --git a/public/store/route/getters.ts b/public/store/route/getters.ts index c49422d79..42c3f3341 100644 --- a/public/store/route/getters.ts +++ b/public/store/route/getters.ts @@ -604,6 +604,14 @@ export const getters = { return task; }, + getRouteSelectedTask(state: Route, getters: any): string { + const selectedTask = state.query.selectedTask as string; + if (!selectedTask) { + return null; + } + return selectedTask; + }, + getDataMode(state: Route, getters: any): DataMode { const mode = state.query.dataMode as string; if (!mode) { diff --git a/public/store/route/module.ts b/public/store/route/module.ts index 3ea2e2c8d..09e6c7f32 100644 --- a/public/store/route/module.ts +++ b/public/store/route/module.ts @@ -134,6 +134,7 @@ export const getters = { getGeoZoom: read(moduleGetters.getGeoZoom), getGroupingType: 
read(moduleGetters.getGroupingType), getRouteTask: read(moduleGetters.getRouteTask), + getRouteSelectedTask: read(moduleGetters.getRouteSelectedTask), getColorScale: read(moduleGetters.getColorScale), getColorScaleVariable: read(moduleGetters.getColorScaleVariable), getImageLayerScale: read(moduleGetters.getImageLayerScale), diff --git a/public/util/routes.ts b/public/util/routes.ts index c3f78340b..23d31a8c9 100644 --- a/public/util/routes.ts +++ b/public/util/routes.ts @@ -59,6 +59,7 @@ export interface RouteArgs { resultTrainingVarsSearch?: string; trainingVarsSearch?: string; task?: string; + selectedTask?: string; dataMode?: string; varModes?: string; varRanked?: string; diff --git a/run.sh b/run.sh index dfc5cfd03..62525d636 100755 --- a/run.sh +++ b/run.sh @@ -35,6 +35,9 @@ export TILE_REQUEST_URL=https://server.arcgisonline.com/ArcGIS/rest/services/Wor export INGEST_SAMPLE_ROW_LIMIT=200000 # export MAX_TRAINING_ROWS=500 # export MAX_TEST_ROWS=500 +export PG_UPDATE=true +export SEGMENTATION_ENABLED=true +export REMOTE_SENSING_GPU_BATCH_SIZE=4 ulimit -n 4096
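run.sh now exports PG_UPDATE, SEGMENTATION_ENABLED, and REMOTE_SENSING_GPU_BATCH_SIZE. Purely as an illustration of what those flags carry, here is a sketch that reads them with os.Getenv; the project's env.Config has its own loader, so this is not the actual parsing code and the defaults shown are assumptions:

```go
// Illustrative only: reading the new run.sh environment flags. Variable
// names come from run.sh; the parsing and defaults are assumptions.
package main

import (
	"fmt"
	"os"
	"strconv"
)

type segmentationConfig struct {
	PostgresUpdate      bool
	SegmentationEnabled bool
	GPUBatchSize        int
}

func loadFromEnv() segmentationConfig {
	cfg := segmentationConfig{GPUBatchSize: 1} // assumed default batch size
	if v, err := strconv.ParseBool(os.Getenv("PG_UPDATE")); err == nil {
		cfg.PostgresUpdate = v
	}
	if v, err := strconv.ParseBool(os.Getenv("SEGMENTATION_ENABLED")); err == nil {
		cfg.SegmentationEnabled = v
	}
	if v, err := strconv.Atoi(os.Getenv("REMOTE_SENSING_GPU_BATCH_SIZE")); err == nil {
		cfg.GPUBatchSize = v
	}
	return cfg
}

func main() {
	fmt.Printf("%+v\n", loadFromEnv())
}
```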