diff --git a/.gcloudignore b/.gcloudignore index 9fdaf41b4..cc86cc624 100644 --- a/.gcloudignore +++ b/.gcloudignore @@ -6,4 +6,5 @@ e2e/ .gitignore *.yaml *.md -secrets.json \ No newline at end of file +secrets.json +.meta \ No newline at end of file diff --git a/.gitignore b/.gitignore index 2bebf504f..711ab3c41 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ secrets.json *.db local_test.go -vendor +vendor/ datly other mydb @@ -16,4 +16,4 @@ logs .extension .datly *.zip -local \ No newline at end of file +v1/ diff --git a/cmd/command/run.go b/cmd/command/run.go index a084275d2..db3ba7a1f 100644 --- a/cmd/command/run.go +++ b/cmd/command/run.go @@ -5,6 +5,7 @@ import ( "github.com/viant/afs/file" "github.com/viant/afs/url" "github.com/viant/datly/cmd/options" + "github.com/viant/datly/gateway" "github.com/viant/datly/gateway/runtime/standalone" "github.com/viant/datly/internal/setter" ) @@ -42,5 +43,16 @@ func (s *Service) run(ctx context.Context, run *options.Run) (*standalone.Server _ = s.fs.Copy(ctx, parent, s.config.Config.PluginsURL) } s.config.Version = run.Version + if run.MCPPort != nil || run.MCPAuthURL != "" || run.MCPIssuerURL != "" || run.MCPAuthMode != "" { + if s.config.Config.MCP == nil { + s.config.Config.MCP = &gateway.ModelContextProtocol{} + } + if run.MCPPort != nil { + s.config.Config.MCP.Port = run.MCPPort + } + setter.SetStringIfEmpty(&s.config.Config.MCP.OAuth2ConfigURL, run.MCPAuthURL) + setter.SetStringIfEmpty(&s.config.Config.MCP.IssuerURL, run.MCPIssuerURL) + setter.SetStringIfEmpty(&s.config.Config.MCP.AuthorizerMode, run.MCPAuthMode) + } return standalone.New(ctx, standalone.WithConfig(s.config)) } diff --git a/cmd/command/service.go b/cmd/command/service.go index fe5de831f..88cd4962c 100644 --- a/cmd/command/service.go +++ b/cmd/command/service.go @@ -62,6 +62,12 @@ func (s *Service) Exec(ctx context.Context, opts *options.Options) error { if opts.Translate != nil { return s.Translate(ctx, opts) } + if opts.Transcribe != nil { + return s.Transcribe(ctx, opts) + } + if opts.Validate != nil { + return s.Validate(ctx, opts) + } if opts.Mcp != nil { return s.Mcp(ctx, opts) diff --git a/cmd/command/transcribe.go b/cmd/command/transcribe.go new file mode 100644 index 000000000..8eec70e0a --- /dev/null +++ b/cmd/command/transcribe.go @@ -0,0 +1,2054 @@ +package command + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path" + "path/filepath" + "reflect" + "regexp" + "strings" + + "github.com/viant/afs" + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/gateway" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/shape" + shapeColumn "github.com/viant/datly/repository/shape/column" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/xgen" + "github.com/viant/datly/shared" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/scy" + "github.com/viant/scy/auth/jwt/signer" + "github.com/viant/scy/auth/jwt/verifier" + "github.com/viant/tagly/format/text" + "github.com/viant/xreflect" + "gopkg.in/yaml.v3" +) + +func (s *Service) Transcribe(ctx context.Context, opts *options.Options) error { + transcribe := opts.Transcribe + if transcribe == nil { + return fmt.Errorf("transcribe options not set") + } + compiler := shapeCompile.New() + loader := shapeLoad.New() + var sources []string + for _, sourceURL := range transcribe.Source { + _, name := url.Split(sourceURL, file.Scheme) + dql, err := s.readSource(ctx, sourceURL) + if err != nil { + return fmt.Errorf("failed to read %s: %w", sourceURL, err) + } + shapeSource := &shape.Source{ + Name: strings.TrimSuffix(name, path.Ext(name)), + Path: url.Path(sourceURL), + DQL: strings.TrimSpace(dql), + Connector: transcribe.DefaultConnectorName(), + } + planResult, err := compiler.Compile(ctx, shapeSource, transcribeCompileOptions(transcribe)...) + if err != nil { + return fmt.Errorf("failed to compile %s: %w", sourceURL, err) + } + componentArtifact, err := loader.LoadComponent(ctx, planResult, shape.WithLoadTypeContextPackages(true)) + if err != nil { + return fmt.Errorf("failed to load %s: %w", sourceURL, err) + } + component, ok := shapeLoad.ComponentFrom(componentArtifact) + if !ok { + return fmt.Errorf("unexpected component artifact for %s", sourceURL) + } + if componentArtifact.Resource != nil && len(transcribe.Connectors) > 0 { + applyConnectorsToResource(componentArtifact.Resource, transcribe.Connectors) + discoverColumns(ctx, componentArtifact.Resource) + shapeLoad.RefineSummarySchemas(componentArtifact.Resource) + } + prepareResourceForTranscribeCodegen(componentArtifact.Resource, component) + codegenResult, err := s.generateTranscribeTypes(sourceURL, dql, transcribe, componentArtifact.Resource, component) + if err != nil { + return err + } + if codegenResult != nil { + alignGeneratedPackageAliases(componentArtifact.Resource, component, codegenResult.PackageDir, codegenResult.PackagePath, codegenResult.PackageName) + } + if !transcribe.SkipYAML { + if err = s.persistTranscribeRoute(ctx, transcribe, sourceURL, dql, componentArtifact.Resource, component, codegenResult); err != nil { + return err + } + } + sources = append(sources, filepath.Clean(url.Path(sourceURL))) + } + return s.persistTranscribeDependencies(ctx, transcribe, sources) +} + +func (s *Service) persistTranscribeDependencies(ctx context.Context, transcribe *options.Transcribe, sources []string) error { + depURL := url.Join(transcribe.Repository, "Datly", "dependencies") + depURL = url.Normalize(depURL, file.Scheme) + if len(transcribe.Connectors) > 0 { + var connectors []connEntry + for _, c := range transcribe.Connectors { + parts := strings.SplitN(c, "|", 4) + if len(parts) >= 3 { + connectors = append(connectors, connEntry{Name: parts[0], Driver: parts[1], DSN: parts[2]}) + } + } + if len(connectors) > 0 { + connURL := url.Join(depURL, "connectors.yaml") + existing := loadExistingConnectors(ctx, s.fs, connURL) + merged := mergeConnectors(existing, connectors) + connMap := map[string]any{"Connectors": merged} + data, err := yaml.Marshal(connMap) + if err != nil { + return err + } + if err = s.fs.Upload(ctx, connURL, file.DefaultFileOsMode, strings.NewReader(string(data))); err != nil { + return fmt.Errorf("failed to persist connections: %w", err) + } + } + } + + cfgURL := url.Join(transcribe.Repository, "Datly", "config.json") + cfg := s.seedTranscribeConfig(ctx, transcribe) + if cfg.SyncFrequencyMs == 0 { + cfg.SyncFrequencyMs = 2000 + } + cfg.Meta.Init() + if cfg.Meta.StatusURI == "" { + cfg.Meta.StatusURI = "/v1/api/status" + } + payload := map[string]any{ + "APIPrefix": cfg.APIPrefix, + "DependencyURL": depURL, + "Endpoint": map[string]any{"Port": 8080}, + "SyncFrequencyMs": cfg.SyncFrequencyMs, + "Meta": cfg.Meta, + } + if transcribe.APIPrefix != "" { + payload["APIPrefix"] = transcribe.APIPrefix + } + if len(cfg.APIKeys) > 0 { + payload["APIKeys"] = cfg.APIKeys + } + if cfg.JWTValidator != nil { + payload["JWTValidator"] = cfg.JWTValidator + } + if cfg.JwtSigner != nil { + payload["JwtSigner"] = cfg.JwtSigner + } + if transcribe.SkipYAML { + payload["DQLBootstrap"] = map[string]any{"Sources": mergeStrings(existingBootstrapSources(ctx, s.fs, cfgURL), sources)} + } else { + payload["RouteURL"] = url.Normalize(url.Join(transcribe.Repository, "Datly", "routes"), file.Scheme) + } + cfgData, err := json.MarshalIndent(payload, "", " ") + if err != nil { + return err + } + if err = s.fs.Upload(ctx, cfgURL, file.DefaultFileOsMode, strings.NewReader(string(cfgData))); err != nil { + return fmt.Errorf("failed to persist config: %w", err) + } + return nil +} + +func (s *Service) seedTranscribeConfig(ctx context.Context, transcribe *options.Transcribe) *gateway.Config { + seed := &gateway.Config{} + projectCfg := filepath.Join(transcribe.Project, "config.json") + if data, err := s.fs.DownloadWithURL(ctx, projectCfg); err == nil { + _ = json.Unmarshal(data, seed) + } + applyAuth(seed, &transcribe.Auth) + return seed +} + +func applyAuth(cfg *gateway.Config, auth *options.Auth) { + if cfg == nil || auth == nil { + return + } + if strings.TrimSpace(auth.RSA) != "" { + cfg.JWTValidator = &verifier.Config{RSA: getScyResources(auth.RSA)} + cfg.JwtSigner = &signer.Config{RSA: getScyResource(strings.Split(auth.RSA, ";")[0])} + } + if strings.TrimSpace(auth.HMAC) != "" { + cfg.JWTValidator = &verifier.Config{HMAC: getScyResource(auth.HMAC)} + cfg.JwtSigner = &signer.Config{HMAC: getScyResource(auth.HMAC)} + } +} + +func getScyResource(location string) *scy.Resource { + pair := strings.Split(location, "|") + res := &scy.Resource{URL: pair[0]} + if len(pair) > 1 { + res.Key = pair[1] + } + res.URL = url.Normalize(res.URL, file.Scheme) + return res +} + +func getScyResources(location string) []*scy.Resource { + var result []*scy.Resource + for _, item := range strings.Split(location, "-") { + item = strings.TrimSpace(item) + if item == "" { + continue + } + result = append(result, getScyResource(item)) + } + return result +} + +func existingBootstrapSources(ctx context.Context, fs afs.Service, cfgURL string) []string { + data, err := fs.DownloadWithURL(ctx, cfgURL) + if err != nil { + return nil + } + cfg := &gateway.Config{} + if err = json.Unmarshal(data, cfg); err != nil || cfg.DQLBootstrap == nil { + return nil + } + return append([]string{}, cfg.DQLBootstrap.Sources...) +} + +func mergeStrings(existing, incoming []string) []string { + seen := map[string]bool{} + var result []string + for _, item := range append(existing, incoming...) { + item = strings.TrimSpace(item) + if item == "" || seen[item] { + continue + } + seen[item] = true + result = append(result, item) + } + return result +} + +type connEntry struct { + Name string `yaml:"Name"` + Driver string `yaml:"Driver"` + DSN string `yaml:"DSN"` +} + +func loadExistingConnectors(ctx context.Context, fs afs.Service, connURL string) []connEntry { + data, err := fs.DownloadWithURL(ctx, connURL) + if err != nil { + return nil + } + var doc struct { + Connectors []connEntry `yaml:"Connectors"` + } + if err := yaml.Unmarshal(data, &doc); err != nil { + return nil + } + return doc.Connectors +} + +func mergeConnectors(existing, incoming []connEntry) []connEntry { + byName := map[string]connEntry{} + var order []string + for _, c := range existing { + if _, ok := byName[c.Name]; !ok { + order = append(order, c.Name) + } + byName[c.Name] = c + } + for _, c := range incoming { + if _, ok := byName[c.Name]; !ok { + order = append(order, c.Name) + } + byName[c.Name] = c + } + result := make([]connEntry, 0, len(order)) + for _, name := range order { + result = append(result, byName[name]) + } + return result +} + +func (s *Service) readSource(ctx context.Context, sourceURL string) (string, error) { + payload, err := s.fs.DownloadWithURL(ctx, sourceURL) + if err != nil { + return "", err + } + return string(payload), nil +} + +func (s *Service) persistTranscribeRoute(ctx context.Context, transcribe *options.Transcribe, sourceURL, dql string, resource *view.Resource, component *shapeLoad.Component, codegenResult *xgen.ComponentCodegenResult) error { + sourcePath := filepath.Clean(url.Path(sourceURL)) + stem := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + routeRoot := url.Join(transcribe.Repository, "Datly", "routes") + routeYAML := url.Join(routeRoot, stem+".yaml") + if err := s.applyGeneratedMutableArtifacts(ctx, routeRoot, resource, component, codegenResult); err != nil { + return err + } + if resource != nil && codegenResult != nil && strings.TrimSpace(codegenResult.VeltyFilePath) != "" { + if source, err := os.ReadFile(filepath.Clean(codegenResult.VeltyFilePath)); err == nil { + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + root := lookupNamedView(resource, rootView) + if root == nil && len(resource.Views) > 0 { + root = resource.Views[0] + } + if root != nil { + if root.Template == nil { + root.Template = view.NewTemplate(stripLeadingRouteDirective(string(source))) + } else { + root.Template.Source = stripLeadingRouteDirective(string(source)) + } + if rel, err := filepath.Rel(filepath.Clean(codegenResult.PackageDir), filepath.Clean(codegenResult.VeltyFilePath)); err == nil { + root.Template.SourceURL = filepath.ToSlash(rel) + } + } + } + } + if resource != nil { + normalizeResourceSchemaPackages(resource) + for _, item := range resource.Views { + if item == nil || item.Template == nil || strings.TrimSpace(item.Template.Source) == "" { + continue + } + if strings.HasSuffix(strings.TrimSpace(item.Template.SourceURL), "/patch.sql") || strings.EqualFold(strings.TrimSpace(item.Name), strings.TrimSpace(component.RootView)) { + item.Template.Source = stripLeadingRouteDirective(item.Template.Source) + } + sqlRel := strings.TrimSpace(item.Template.SourceURL) + if sqlRel == "" { + sqlRel = path.Join(stem, item.Name+".sql") + } + sqlDest := path.Join(url.Path(routeRoot), filepath.ToSlash(sqlRel)) + if err := s.fs.Upload(ctx, sqlDest, file.DefaultFileOsMode, strings.NewReader(item.Template.Source)); err != nil { + return fmt.Errorf("failed to persist sql %s: %w", sqlDest, err) + } + item.Template.SourceURL = sqlRel + } + } + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + if rootView == "" && resource != nil && len(resource.Views) > 0 && resource.Views[0] != nil { + rootView = resource.Views[0].Name + } + method, uri := transcribeRulePath(dql, stem, transcribe.APIPrefix, component) + routeView := &view.View{ + Reference: shared.Reference{Ref: rootView}, + Name: rootView, + } + if root := lookupNamedView(resource, rootView); root != nil { + viewCopy := *root + routeView = &viewCopy + routeView.Reference = shared.Reference{Ref: rootView} + } + route := &repository.Component{ + Path: contract.Path{ + Method: method, + URI: uri, + }, + Contract: contract.Contract{ + Service: serviceTypeForMethod(method), + Output: contract.Output{ + CaseFormat: text.CaseFormatLowerCamel, + }, + }, + View: routeView, + } + if root := lookupNamedView(resource, rootView); root != nil && root.Connector != nil { + ref := strings.TrimSpace(root.Connector.Ref) + if ref == "" { + ref = strings.TrimSpace(root.Connector.Name) + } + if ref != "" { + route.View.Connector = view.NewRefConnector(ref) + route.View.Connector.Name = ref + } + } + if component != nil { + route.TypeContext = component.TypeContext + if output := component.OutputParameters(); len(output) > 0 { + route.Contract.Output = contract.Output{ + Cardinality: state.Many, + CaseFormat: text.CaseFormatLowerCamel, + Type: state.Type{ + Parameters: output, + }, + } + } + if component.Directives != nil && component.Directives.MCP != nil { + route.Name = strings.TrimSpace(component.Directives.MCP.Name) + route.Description = strings.TrimSpace(component.Directives.MCP.Description) + route.DescriptionURI = strings.TrimSpace(component.Directives.MCP.DescriptionPath) + } + } + if component != nil && (len(component.Input) > 0 || len(component.Meta) > 0) { + params := transcribeInputParameters(component, resource) + if len(params) > 0 { + normalizeParameterSchemas(params) + route.Contract.Input.Type.Parameters = normalizeParameterTypeNameTags(params) + } + } + if component != nil { + normalizeComponentStateSchemas(component) + } + payload := &shapeRuleFile{ + Routes: []*repository.Component{route}, + Resource: sanitizeResourceForRouteYAML(resource), + With: transcribeSharedResourceRefs(resource), + } + if payload.Resource != nil { + normalizeResourceSchemaPackages(payload.Resource) + promoteAnonymousParameterTypeDefinitions(payload.Resource) + backfillResourceColumnDataTypes(payload.Resource) + canonicalizeResourceTypeDefinitions(payload.Resource) + } + if len(payload.Routes) > 0 { + if payload.Resource != nil && payload.Routes[0] != nil && payload.Routes[0].View != nil { + alignViewParameterSchemasToResourceTypes(payload.Routes[0].View, payload.Resource) + } + normalizeParameterSchemas(payload.Routes[0].Contract.Input.Type.Parameters) + normalizeParameterSchemas(payload.Routes[0].Contract.Output.Type.Parameters) + } + if component != nil && component.TypeContext != nil { + payload.TypeContext = component.TypeContext + } + if payload.Resource != nil && codegenResult != nil && strings.TrimSpace(codegenResult.VeltyFilePath) != "" { + if source, err := os.ReadFile(filepath.Clean(codegenResult.VeltyFilePath)); err == nil { + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + root := lookupNamedView(payload.Resource, rootView) + if root == nil && len(payload.Resource.Views) > 0 { + root = payload.Resource.Views[0] + } + if root != nil { + if root.Template == nil { + root.Template = view.NewTemplate(stripLeadingRouteDirective(string(source))) + } else { + root.Template.Source = stripLeadingRouteDirective(string(source)) + } + if rel, err := filepath.Rel(filepath.Clean(codegenResult.PackageDir), filepath.Clean(codegenResult.VeltyFilePath)); err == nil { + root.Template.SourceURL = filepath.ToSlash(rel) + } + } + } + } + data, err := yaml.Marshal(payload) + if err != nil { + return err + } + data, err = ensureSharedResourceRefsYAML(data, payload.With) + if err != nil { + return err + } + data, err = normalizeConnectorRefsYAML(data) + if err != nil { + return err + } + data, err = normalizeRouteViewRefsYAML(data) + if err != nil { + return err + } + data, err = normalizeRouteComponentEmbeddingYAML(data) + if err != nil { + return err + } + if err = s.fs.Upload(ctx, routeYAML, file.DefaultFileOsMode, strings.NewReader(string(data))); err != nil { + return fmt.Errorf("failed to persist route yaml %s: %w", routeYAML, err) + } + return nil +} + +func transcribeSharedResourceRefs(resource *view.Resource) []string { + if resource == nil { + return nil + } + var result []string + if len(collectResourceConnectorRefs(resource)) > 0 { + result = append(result, view.ResourceConnectors) + } + if len(resource.CacheProviders) > 0 { + result = append(result, "cache") + } + return result +} + +func (s *Service) generateTranscribeTypes(sourceAbsPath, dql string, transcribe *options.Transcribe, resource *view.Resource, component *shapeLoad.Component) (*xgen.ComponentCodegenResult, error) { + if component == nil || component.TypeContext == nil || resource == nil { + return nil, nil + } + ctx := component.TypeContext + projectDir := findProjectDir(sourceAbsPath) + if projectDir == "" { + projectDir = transcribe.Project + } + codegen := &xgen.ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: !transcribe.SkipYAML, + WithContract: true, + } + if pkgPath, pkgDir, pkgName := resolvedTranscribeTypeOutput(projectDir, ctx.PackagePath); pkgPath != "" { + codegen.PackagePath = pkgPath + codegen.PackageDir = pkgDir + codegen.PackageName = pkgName + } + if method, uri := resolvedTranscribeRoute(sourceAbsPath, dql, transcribe.APIPrefix); uri != "" { + component.Method = method + component.URI = uri + } + return codegen.Generate() +} + +func (s *Service) applyGeneratedMutableArtifacts(ctx context.Context, routeRoot string, resource *view.Resource, component *shapeLoad.Component, codegenResult *xgen.ComponentCodegenResult) error { + if resource == nil || component == nil || codegenResult == nil || codegenResult.PackageDir == "" { + return nil + } + uploaded := map[string]string{} + packageDir := filepath.Clean(codegenResult.PackageDir) + for _, generated := range codegenResult.GeneratedFiles { + if strings.TrimSpace(generated) == "" || !strings.HasSuffix(strings.TrimSpace(generated), ".sql") { + continue + } + absPath := filepath.Clean(generated) + rel, err := filepath.Rel(packageDir, absPath) + if err != nil { + continue + } + rel = filepath.ToSlash(rel) + if strings.HasPrefix(rel, "../") { + continue + } + data, err := os.ReadFile(absPath) + if err != nil { + return fmt.Errorf("failed to read generated sql %s: %w", absPath, err) + } + content := string(data) + if strings.TrimSpace(codegenResult.VeltyFilePath) != "" && filepath.Clean(codegenResult.VeltyFilePath) == absPath { + content = stripLeadingRouteDirective(content) + } + dest := path.Join(url.Path(routeRoot), rel) + if err = s.fs.Upload(ctx, dest, file.DefaultFileOsMode, strings.NewReader(content)); err != nil { + return fmt.Errorf("failed to persist generated sql %s: %w", dest, err) + } + uploaded[rel] = content + } + if len(uploaded) == 0 { + return nil + } + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + root := lookupNamedView(resource, rootView) + if root == nil && resource != nil && len(resource.Views) > 0 { + root = resource.Views[0] + } + if root != nil && strings.TrimSpace(codegenResult.VeltyFilePath) != "" { + if rel, err := filepath.Rel(packageDir, filepath.Clean(codegenResult.VeltyFilePath)); err == nil { + rel = filepath.ToSlash(rel) + if source, ok := uploaded[rel]; ok { + if root.Template == nil { + root.Template = view.NewTemplate(source) + } else { + root.Template.Source = source + } + root.Template.SourceURL = rel + preserveTemplateParameters(root, component.InputParameters()) + } + } + } + for _, item := range resource.Views { + if item == nil || item.Template == nil { + continue + } + rel := strings.TrimSpace(item.Template.SourceURL) + if item.Template.DeclaredParametersOnly { + item.Template.Parameters = append(state.Parameters{}, resource.Parameters.UsedBy(item.Template.Source)...) + } else { + preserveTemplateParameters(item, resource.Parameters.UsedBy(item.Template.Source)) + preserveTemplateParameters(item, dependentTemplateParameters(item.Template.Parameters, resource.Parameters)) + } + if len(item.Template.Parameters) > 0 { + item.Template.UseParameterStateType = true + } + if rel == "" { + continue + } + if source, ok := uploaded[rel]; ok { + item.Template.Source = source + } + } + return nil +} + +func preserveTemplateParameters(aView *view.View, params state.Parameters) { + if aView == nil || aView.Template == nil || len(params) == 0 { + return + } + if aView.Template.DeclaredParametersOnly { + return + } + seen := map[string]bool{} + for _, item := range aView.Template.Parameters { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + seen[strings.ToLower(strings.TrimSpace(item.Name))] = true + } + for _, param := range params { + if param == nil || strings.TrimSpace(param.Name) == "" { + continue + } + switch param.In.Kind { + case state.KindOutput, state.KindMeta, state.KindAsync: + continue + } + key := strings.ToLower(strings.TrimSpace(param.Name)) + if seen[key] { + continue + } + aView.Template.Parameters = append(aView.Template.Parameters, param) + seen[key] = true + } +} + +func transcribeInputParameters(component *shapeLoad.Component, resource *view.Resource) state.Parameters { + if component == nil { + return nil + } + params := make(state.Parameters, 0, len(component.Input)+len(component.Meta)+4) + seen := map[string]bool{} + declared := map[string]bool{} + appendParam := func(param *state.Parameter) { + if param == nil { + return + } + key := strings.ToLower(strings.TrimSpace(param.Name)) + if key == "" || seen[key] { + return + } + cloned := *param + if param.Schema != nil { + cloned.Schema = param.Schema.Clone() + } + if param.Output != nil { + output := *param.Output + if param.Output.Schema != nil { + output.Schema = param.Output.Schema.Clone() + } + cloned.Output = &output + } + params = append(params, &cloned) + seen[key] = true + } + for _, item := range component.Input { + if item != nil { + if name := strings.ToLower(strings.TrimSpace(item.Name)); name != "" { + declared[name] = true + } + appendParam(&item.Parameter) + } + } + rootView := lookupNamedView(resource, strings.TrimSpace(component.RootView)) + if rootView != nil && rootView.Template != nil { + for _, item := range rootView.Template.Parameters { + if item == nil { + continue + } + if !declared[strings.ToLower(strings.TrimSpace(item.Name))] { + continue + } + appendParam(item) + } + } + for _, item := range component.Meta { + if item != nil { + if name := strings.ToLower(strings.TrimSpace(item.Name)); name != "" { + declared[name] = true + } + appendParam(&item.Parameter) + } + } + return params +} + +func prepareResourceForTranscribeCodegen(resource *view.Resource, component *shapeLoad.Component) { + if resource == nil || component == nil { + return + } + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + root := lookupNamedView(resource, rootView) + if root == nil && len(resource.Views) > 0 { + root = resource.Views[0] + } + if root != nil && root.Template != nil { + preserveTemplateParameters(root, component.InputParameters()) + preserveTemplateParameters(root, resource.Parameters) + if len(root.Template.Parameters) > 0 { + root.Template.UseParameterStateType = true + } + } + for _, item := range resource.Views { + if item == nil || item.Template == nil { + continue + } + if item.Template.DeclaredParametersOnly { + item.Template.Parameters = append(state.Parameters{}, resource.Parameters.UsedBy(item.Template.Source)...) + } else { + preserveTemplateParameters(item, resource.Parameters.UsedBy(item.Template.Source)) + preserveTemplateParameters(item, dependentTemplateParameters(item.Template.Parameters, resource.Parameters)) + } + if len(item.Template.Parameters) > 0 { + item.Template.UseParameterStateType = true + } + } +} + +func dependentTemplateParameters(params state.Parameters, resourceParams state.Parameters) state.Parameters { + if len(params) == 0 || len(resourceParams) == 0 { + return nil + } + seen := map[string]bool{} + result := make(state.Parameters, 0) + for _, param := range params { + if param == nil || param.In == nil { + continue + } + switch param.In.Kind { + case state.KindParam: + name := strings.TrimSpace(param.In.Name) + if name == "" { + continue + } + key := strings.ToLower(name) + if seen[key] { + continue + } + if dep := resourceParams.Lookup(name); dep != nil { + result = append(result, dep) + seen[key] = true + } + } + } + return result +} + +func stripLeadingRouteDirective(content string) string { + trimmed := strings.TrimSpace(content) + if !strings.HasPrefix(trimmed, "/*") { + return content + } + end := strings.Index(trimmed, "*/") + if end == -1 { + return content + } + header := trimmed[2:end] + if !strings.Contains(header, `"URI"`) || !strings.Contains(header, `"Method"`) { + return content + } + return strings.TrimSpace(trimmed[end+2:]) + "\n" +} + +func lookupNamedView(resource *view.Resource, name string) *view.View { + if resource == nil || strings.TrimSpace(name) == "" { + return nil + } + for _, item := range resource.Views { + if item != nil && strings.EqualFold(strings.TrimSpace(item.Name), strings.TrimSpace(name)) { + return item + } + } + return nil +} + +var routeDirectivePattern = regexp.MustCompile(`\$route\(\s*['"]([^'"]+)['"](?:\s*,\s*['"]([^'"]+)['"])?`) + +func resolvedTranscribeRoute(sourcePath, dql, apiPrefix string) (string, string) { + matches := routeDirectivePattern.FindStringSubmatch(dql) + if len(matches) > 0 { + method := strings.ToUpper(strings.TrimSpace(matches[2])) + if method == "" { + method = "GET" + } + return method, strings.TrimSpace(matches[1]) + } + stem := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + uri := "/" + strings.Trim(stem, "/") + if prefix := strings.TrimSpace(apiPrefix); prefix != "" { + uri = strings.TrimRight(prefix, "/") + uri + } + return "GET", uri +} +func resolvedTranscribeTypeOutput(projectDir, packagePath string) (string, string, string) { + projectDir = strings.TrimSpace(projectDir) + packagePath = strings.TrimSpace(packagePath) + if projectDir == "" || packagePath == "" { + return "", "", "" + } + modulePath, err := transcribeModulePath(filepath.Join(projectDir, "go.mod")) + if err != nil || modulePath == "" { + return "", "", "" + } + prefix := strings.TrimRight(modulePath, "/") + "/" + if !strings.HasPrefix(packagePath, prefix) { + return "", "", "" + } + rel := strings.TrimPrefix(packagePath, prefix) + rel = sanitizeTypeNamespace(rel) + if rel == "" { + return "", "", "" + } + pkgDir := filepath.Join(projectDir, filepath.FromSlash(rel)) + pkgName := filepath.Base(rel) + return strings.TrimRight(modulePath, "/") + "/" + rel, pkgDir, pkgName +} + +func transcribeModulePath(goModPath string) (string, error) { + data, err := os.ReadFile(goModPath) + if err != nil { + return "", err + } + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "module ") { + return strings.TrimSpace(strings.TrimPrefix(line, "module ")), nil + } + } + return "", fmt.Errorf("module path not found in %s", goModPath) +} +func transcribeCompileOptions(transcribe *options.Transcribe) []shape.CompileOption { + var opts []shape.CompileOption + if transcribe.Strict { + opts = append(opts, shape.WithCompileStrict(true)) + } + opts = append(opts, shape.WithLinkedTypes(false)) + namespace := strings.TrimSpace(transcribe.Namespace) + module := strings.TrimSpace(transcribe.Module) + typeOutput := strings.TrimSpace(transcribe.TypeOutput) + if typeOutput == "" || typeOutput == "." { + typeOutput = module + } + if namespace != "" { + sanitizedNamespace := sanitizeTypeNamespace(namespace) + pkgDir := filepath.Join(typeOutput, sanitizedNamespace) + pkgName := filepath.Base(sanitizedNamespace) + opts = append(opts, shape.WithTypeContextPackageDir(pkgDir)) + opts = append(opts, shape.WithTypeContextPackageName(pkgName)) + } + return opts +} + +func sanitizeTypeNamespace(namespace string) string { + parts := strings.Split(strings.ReplaceAll(strings.TrimSpace(namespace), "\\", "/"), "/") + for i, part := range parts { + part = strings.TrimSpace(part) + switch part { + case "": + continue + case "vendor": + part = "vendorsrc" + default: + part = sanitizeTypeNamespaceSegment(part) + } + parts[i] = part + } + return path.Join(parts...) +} + +func sanitizeTypeNamespaceSegment(segment string) string { + var b strings.Builder + for _, r := range segment { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r) + case r >= 'A' && r <= 'Z': + b.WriteRune(r + ('a' - 'A')) + case r >= '0' && r <= '9': + b.WriteRune(r) + case r == '_' || r == '-': + b.WriteRune('_') + } + } + if b.Len() == 0 { + return "generated" + } + ret := b.String() + if ret[0] >= '0' && ret[0] <= '9' { + return "p" + ret + } + return ret +} + +func transcribeRulePath(_ string, ruleName, apiPrefix string, component *shapeLoad.Component) (string, string) { + method := "GET" + uri := "/" + strings.Trim(strings.TrimSpace(ruleName), "/") + if prefix := strings.TrimSpace(apiPrefix); prefix != "" { + uri = strings.TrimRight(prefix, "/") + uri + } + if component != nil { + if u := strings.TrimSpace(component.URI); u != "" { + uri = u + } + if m := strings.TrimSpace(strings.ToUpper(component.Method)); m != "" { + method = m + } + } + return method, uri +} + +func discoverColumns(ctx context.Context, resource *view.Resource) { + if resource == nil { + return + } + detector := shapeColumn.New() + for _, aView := range resource.Views { + if aView == nil { + continue + } + columns, err := detector.Resolve(ctx, resource, aView) + if err == nil && len(columns) > 0 { + aView.Columns = columns + } + } +} + +func applyConnectorsToResource(resource *view.Resource, connectors []string) { + if resource == nil || len(connectors) == 0 { + return + } + defaultName := "" + for _, c := range connectors { + parts := strings.SplitN(c, "|", 4) + if len(parts) < 1 { + continue + } + name := strings.TrimSpace(parts[0]) + if name == "" { + continue + } + if defaultName == "" { + defaultName = name + } + if len(parts) >= 3 { + resource.AddConnectors(view.NewConnector(name, strings.TrimSpace(parts[1]), strings.TrimSpace(parts[2]))) + } + } + if defaultName == "" { + return + } + for _, v := range resource.Views { + if v != nil && v.Connector == nil { + v.Connector = view.NewRefConnector(defaultName) + } + } +} + +func sanitizeResourceForRouteYAML(resource *view.Resource) *view.Resource { + if resource == nil { + return nil + } + cloned := *resource + cloned.Parameters = normalizeParameterTypeNameTags(cloneParameters(resource.Parameters)) + if refs := collectResourceConnectorRefs(resource); len(refs) > 0 { + cloned.Connectors = make([]*view.Connector, 0, len(refs)) + for _, ref := range refs { + refConnector := view.NewRefConnector(ref) + refConnector.Name = ref + cloned.Connectors = append(cloned.Connectors, refConnector) + } + } else { + cloned.Connectors = nil + } + if len(resource.Views) > 0 { + cloned.Views = make(view.Views, 0, len(resource.Views)) + for _, item := range resource.Views { + if item == nil { + continue + } + viewCopy := *item + if item.Connector != nil { + ref := strings.TrimSpace(item.Connector.Ref) + if ref == "" { + ref = strings.TrimSpace(item.Connector.Name) + } + if ref != "" { + refConnector := view.NewRefConnector(ref) + refConnector.Name = ref + viewCopy.Connector = refConnector + } else { + viewCopy.Connector = nil + } + } + cloned.Views = append(cloned.Views, &viewCopy) + } + } else { + cloned.Views = nil + } + return &cloned +} + +func collectResourceConnectorRefs(resource *view.Resource) []string { + if resource == nil { + return nil + } + seen := map[string]bool{} + var result []string + appendRef := func(connector *view.Connector) { + if connector == nil { + return + } + ref := strings.TrimSpace(connector.Ref) + if ref == "" { + ref = strings.TrimSpace(connector.Name) + } + if ref == "" || seen[ref] { + return + } + seen[ref] = true + result = append(result, ref) + } + var visitView func(aView *view.View) + visitView = func(aView *view.View) { + if aView == nil { + return + } + appendRef(aView.Connector) + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + visitView(&rel.Of.View) + } + } + for _, connector := range resource.Connectors { + appendRef(connector) + } + for _, aView := range resource.Views { + visitView(aView) + } + return result +} + +func cloneParameters(params state.Parameters) state.Parameters { + if len(params) == 0 { + return nil + } + result := make(state.Parameters, 0, len(params)) + for _, item := range params { + if item == nil { + continue + } + cloned := *item + if item.Schema != nil { + cloned.Schema = item.Schema.Clone() + } + if item.Output != nil { + output := *item.Output + if item.Output.Schema != nil { + output.Schema = item.Output.Schema.Clone() + } + cloned.Output = &output + } + result = append(result, &cloned) + } + return result +} + +func normalizeParameterTypeNameTags(params state.Parameters) state.Parameters { + if len(params) == 0 { + return params + } + for _, item := range params { + if item == nil || item.Schema == nil { + continue + } + typeName := strings.TrimSpace(item.Schema.Name) + if typeName == "" { + continue + } + item.Tag = ensureTypeNameTag(item.Tag, typeName) + } + return params +} + +func normalizeComponentStateSchemas(component *shapeLoad.Component) { + if component == nil { + return + } + for _, item := range component.Input { + if item != nil { + normalizeSchemaPackage(item.Schema) + } + } + for _, item := range component.Output { + if item != nil { + normalizeSchemaPackage(item.Schema) + } + } + for _, item := range component.Meta { + if item != nil { + normalizeSchemaPackage(item.Schema) + } + } +} + +func normalizeResourceSchemaPackages(resource *view.Resource) { + if resource == nil { + return + } + normalizeParameterSchemas(resource.Parameters) + for _, aView := range resource.Views { + normalizeViewSchemaPackages(aView) + } + for _, item := range resource.Types { + if item == nil { + continue + } + item.Package = normalizedSchemaPackage(item.Package, item.ModulePath) + for _, field := range item.Fields { + if field == nil { + continue + } + normalizeSchemaPackage(field.Schema) + } + } +} + +func backfillResourceColumnDataTypes(resource *view.Resource) { + if resource == nil { + return + } + typeDefs := map[string]*view.TypeDefinition{} + for _, item := range resource.Types { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + typeDefs[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + var visitView func(aView *view.View) + visitView = func(aView *view.View) { + if aView == nil { + return + } + backfillViewColumnDataTypes(aView, typeDefs) + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + visitView(&rel.Of.View) + } + } + for _, aView := range resource.Views { + visitView(aView) + } +} + +func backfillViewColumnDataTypes(aView *view.View, defs map[string]*view.TypeDefinition) { + if aView == nil || len(aView.Columns) == 0 || aView.Schema == nil { + return + } + typeName := strings.TrimSpace(aView.Schema.Name) + if typeName == "" { + return + } + def := defs[strings.ToLower(typeName)] + if def == nil { + return + } + fieldTypes := map[string]string{} + for _, field := range def.Fields { + if field == nil || field.Schema == nil { + continue + } + dataType := strings.TrimSpace(firstNonEmpty(field.Schema.DataType, field.Schema.Name)) + if dataType == "" { + continue + } + for _, key := range []string{ + strings.ToUpper(strings.TrimSpace(field.Name)), + strings.ToUpper(strings.TrimSpace(field.Column)), + strings.ToUpper(strings.TrimSpace(field.FromName)), + } { + if key != "" { + fieldTypes[key] = dataType + } + } + } + for _, column := range aView.Columns { + if column == nil || strings.TrimSpace(column.DataType) != "" { + continue + } + for _, key := range []string{ + strings.ToUpper(strings.TrimSpace(column.Name)), + strings.ToUpper(strings.TrimSpace(column.DatabaseColumn)), + strings.ToUpper(strings.TrimSpace(column.FieldName())), + } { + if dataType := strings.TrimSpace(fieldTypes[key]); dataType != "" { + column.DataType = dataType + break + } + } + } +} + +func canonicalizeResourceTypeDefinitions(resource *view.Resource) { + if resource == nil { + return + } + for _, def := range resource.Types { + canonicalizeTypeDefinition(def) + } +} + +func promoteAnonymousParameterTypeDefinitions(resource *view.Resource) { + if resource == nil { + return + } + existing := map[string]bool{} + for _, def := range resource.Types { + if def == nil || strings.TrimSpace(def.Name) == "" { + continue + } + existing[strings.ToLower(strings.TrimSpace(def.Name))] = true + } + promoted := map[string]string{} + for _, param := range resource.Parameters { + if param == nil || param.Schema == nil { + continue + } + typeName := promotedParameterTypeName(param) + if typeName == "" { + continue + } + key := strings.ToLower(typeName) + if !existing[key] { + def := typeDefinitionFromAnonymousParameter(typeName, param) + if def == nil { + continue + } + resource.Types = append(resource.Types, def) + existing[key] = true + } + promoted[strings.ToLower(strings.TrimSpace(param.Name))] = typeName + rewritePromotedParameterSchema(param, typeName) + } + if len(promoted) == 0 { + return + } + visitResourceParameters(resource, func(param *state.Parameter) { + if param == nil || param.Schema == nil { + return + } + typeName := promoted[strings.ToLower(strings.TrimSpace(param.Name))] + if typeName == "" { + return + } + rewritePromotedParameterSchema(param, typeName) + }) +} + +func promotedParameterTypeName(param *state.Parameter) string { + if param == nil || param.Schema == nil { + return "" + } + if strings.TrimSpace(param.Name) == "" { + return "" + } + if strings.TrimSpace(param.Schema.Name) != "" && !strings.Contains(strings.TrimSpace(param.Schema.Name), "struct {") { + return "" + } + dataType := strings.TrimSpace(param.Schema.DataType) + rType := param.Schema.Type() + if !strings.Contains(dataType, "struct {") { + if rType == nil { + return "" + } + base := rType + for base.Kind() == reflect.Ptr || base.Kind() == reflect.Slice || base.Kind() == reflect.Array { + base = base.Elem() + } + if base.Kind() != reflect.Struct || base.Name() != "" { + return "" + } + } + return state.SanitizeTypeName(strings.TrimSpace(param.Name)) +} + +func typeDefinitionFromAnonymousParameter(typeName string, param *state.Parameter) *view.TypeDefinition { + if param == nil || param.Schema == nil || typeName == "" { + return nil + } + fields := typeDefinitionFieldsFromReflectType(param.Schema.Type()) + if len(fields) == 0 { + return nil + } + for _, field := range fields { + if field != nil { + field.Tag = "" + } + } + def := &view.TypeDefinition{Name: typeName, Fields: dedupeTypeDefinitionFields(fields)} + canonicalizeTypeDefinition(def) + return def +} + +func rewritePromotedParameterSchema(param *state.Parameter, typeName string) { + if param == nil || param.Schema == nil || typeName == "" { + return + } + param.Schema.Name = typeName + param.Schema.DataType = typeName + param.Schema.Package = "" + param.Schema.PackagePath = "" + param.Schema.ModulePath = "" +} + +func visitResourceParameters(resource *view.Resource, visitor func(param *state.Parameter)) { + if resource == nil || visitor == nil { + return + } + for _, param := range resource.Parameters { + visitor(param) + } + var visitView func(aView *view.View) + visitView = func(aView *view.View) { + if aView == nil { + return + } + if aView.Template != nil { + for _, param := range aView.Template.Parameters { + visitor(param) + } + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + visitView(&rel.Of.View) + } + } + for _, aView := range resource.Views { + visitView(aView) + } +} + +func alignViewParameterSchemasToResourceTypes(aView *view.View, resource *view.Resource) { + if aView == nil || resource == nil { + return + } + typeNames := map[string]string{} + for _, def := range resource.Types { + if def == nil || strings.TrimSpace(def.Name) == "" { + continue + } + typeNames[strings.ToLower(strings.TrimSpace(def.Name))] = strings.TrimSpace(def.Name) + } + var visitView func(current *view.View) + visitView = func(current *view.View) { + if current == nil { + return + } + if current.Template != nil { + for _, param := range current.Template.Parameters { + if param == nil || param.Schema == nil { + continue + } + typeName := typeNames[strings.ToLower(strings.TrimSpace(param.Name))] + if typeName == "" { + continue + } + rewritePromotedParameterSchema(param, typeName) + } + } + for _, rel := range current.With { + if rel == nil || rel.Of == nil { + continue + } + visitView(&rel.Of.View) + } + } + visitView(aView) +} + +func canonicalizeTypeDefinition(def *view.TypeDefinition) { + if def == nil || len(def.Fields) == 0 { + return + } + fields := dedupeTypeDefinitionFields(def.Fields) + if len(fields) == 0 { + return + } + def.DataType = inlineStructDataType(fields) + def.Fields = nil + def.Schema = nil +} + +func dedupeTypeDefinitionFields(fields []*view.Field) []*view.Field { + type keyedField struct { + key string + field *view.Field + } + var ordered []keyedField + index := map[string]int{} + for _, field := range fields { + if field == nil { + continue + } + key := canonicalTypeFieldKey(field) + if key == "" { + continue + } + if pos, ok := index[key]; ok { + merged := mergeTypeFields(ordered[pos].field, field) + if preferTypeField(field, ordered[pos].field) { + ordered[pos].field = merged + } else { + ordered[pos].field = merged + } + continue + } + index[key] = len(ordered) + ordered = append(ordered, keyedField{key: key, field: cloneTypeField(field)}) + } + result := make([]*view.Field, 0, len(ordered)) + for _, item := range ordered { + if item.field != nil { + item.field.Tag = sanitizeTypeFieldTag(item.field.Tag, item.field) + result = append(result, item.field) + } + } + return result +} + +func mergeTypeFields(primary, secondary *view.Field) *view.Field { + result := cloneTypeField(primary) + if result == nil { + return cloneTypeField(secondary) + } + if secondary == nil { + return result + } + if strings.TrimSpace(result.Column) == "" { + result.Column = strings.TrimSpace(secondary.Column) + } + if strings.TrimSpace(result.FromName) == "" { + result.FromName = strings.TrimSpace(secondary.FromName) + } + if strings.TrimSpace(result.Tag) == "" { + result.Tag = strings.TrimSpace(secondary.Tag) + } + if result.Schema == nil && secondary.Schema != nil { + result.Schema = secondary.Schema.Clone() + } + return result +} + +func cloneTypeField(field *view.Field) *view.Field { + if field == nil { + return nil + } + cloned := *field + if field.Schema != nil { + cloned.Schema = field.Schema.Clone() + } + return &cloned +} + +func canonicalTypeFieldKey(field *view.Field) string { + for _, candidate := range []string{ + strings.TrimSpace(field.Column), + strings.TrimSpace(field.FromName), + strings.TrimSpace(field.Name), + } { + if candidate != "" { + return strings.ToUpper(candidate) + } + } + return "" +} + +func preferTypeField(candidate, existing *view.Field) bool { + if existing == nil { + return true + } + candidateScore := typeFieldPreferenceScore(candidate) + existingScore := typeFieldPreferenceScore(existing) + if candidateScore != existingScore { + return candidateScore > existingScore + } + return strings.TrimSpace(candidate.Name) < strings.TrimSpace(existing.Name) +} + +func typeFieldPreferenceScore(field *view.Field) int { + if field == nil { + return -1 + } + score := 0 + name := strings.TrimSpace(field.Name) + if name != "" && name != strings.ToUpper(name) { + score += 10 + } + if strings.TrimSpace(field.Column) == "" { + score += 3 + } + if strings.TrimSpace(field.FromName) == name { + score += 2 + } + if strings.EqualFold(name, "Has") { + score += 5 + } + return score +} + +func inlineStructDataType(fields []*view.Field) string { + parts := make([]string, 0, len(fields)) + for _, field := range fields { + if field == nil || field.Schema == nil { + continue + } + typeName := strings.TrimSpace(firstNonEmpty(field.Schema.DataType, field.Schema.Name)) + if typeName == "" { + continue + } + tag := strings.TrimSpace(stripVeltyTag(field.Tag)) + if tag != "" { + parts = append(parts, fmt.Sprintf(`%s %s %q`, strings.TrimSpace(field.Name), typeName, tag)) + continue + } + parts = append(parts, fmt.Sprintf(`%s %s`, strings.TrimSpace(field.Name), typeName)) + } + return "struct { " + strings.Join(parts, "; ") + " }" +} + +func typeDefinitionFieldsFromReflectType(rType reflect.Type) []*view.Field { + if rType == nil { + return nil + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + result := make([]*view.Field, 0, rType.NumField()) + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if !field.IsExported() { + continue + } + result = append(result, &view.Field{ + Name: field.Name, + Schema: schemaFromReflectType(field.Type), + Tag: string(field.Tag), + FromName: field.Name, + }) + } + return result +} + +func schemaFromReflectType(rType reflect.Type) *state.Schema { + if rType == nil { + return nil + } + schema := state.NewSchema(rType) + if schema == nil { + return nil + } + if schema.Name == "" && schema.DataType == "" { + schema.DataType = rType.String() + if schema.Cardinality == "" { + schema.Cardinality = state.One + } + } + if schema.Cardinality == state.Many && schema.DataType == "" { + schema.DataType = rType.String() + } + return schema +} + +func stripVeltyTag(tag string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return "" + } + updated, _ := xreflect.RemoveTag(tag, "velty") + return strings.TrimSpace(updated) +} + +func sanitizeTypeFieldTag(tag string, field *view.Field) string { + tag = stripVeltyTag(tag) + if field == nil || field.Schema == nil { + return tag + } + dataType := strings.TrimSpace(firstNonEmpty(field.Schema.DataType, field.Schema.Name)) + if strings.HasPrefix(dataType, "*struct {") || strings.HasPrefix(dataType, "struct {") { + updated, _ := xreflect.RemoveTag(tag, "typeName") + tag = strings.TrimSpace(updated) + } + return tag +} + +func normalizeViewSchemaPackages(aView *view.View) { + if aView == nil { + return + } + normalizeSchemaPackage(aView.Schema) + if aView.Template != nil { + normalizeParameterSchemas(aView.Template.Parameters) + if aView.Template.Summary != nil { + normalizeSchemaPackage(aView.Template.Summary.Schema) + } + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + normalizeViewSchemaPackages(&rel.Of.View) + } +} + +func normalizeParameterSchemas(params state.Parameters) { + for _, param := range params { + if param == nil { + continue + } + normalizeSchemaPackage(param.Schema) + if param.Output != nil { + normalizeSchemaPackage(param.Output.Schema) + } + } +} + +func normalizeSchemaPackage(schema *state.Schema) { + if schema == nil { + return + } + schema.Package = normalizedSchemaPackage(schema.Package, firstNonEmpty(schema.PackagePath, schema.ModulePath)) + if strings.TrimSpace(schema.PackagePath) == "" && strings.Contains(strings.TrimSpace(schema.Package), "/") { + schema.PackagePath = strings.TrimSpace(schema.Package) + } +} + +func normalizedSchemaPackage(pkg, pkgPath string) string { + pkg = strings.TrimSpace(pkg) + pkgPath = strings.TrimSpace(pkgPath) + if pkgPath == "" && strings.Contains(pkg, "/") { + pkgPath = pkg + } + if strings.Contains(pkg, "/") { + return path.Base(pkg) + } + if pkg == "" && pkgPath != "" { + return path.Base(pkgPath) + } + return pkg +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} + +func alignGeneratedPackageAliases(resource *view.Resource, component *shapeLoad.Component, packageDir, packagePath, packageName string) { + packageDir = strings.TrimSpace(packageDir) + packagePath = strings.TrimSpace(packagePath) + packageName = strings.TrimSpace(packageName) + if packagePath == "" || packageName == "" { + return + } + if component != nil && component.TypeContext != nil { + if packageDir != "" { + packageDir = filepath.ToSlash(filepath.Clean(packageDir)) + } + if strings.TrimSpace(component.TypeContext.PackagePath) == packagePath { + component.TypeContext.PackageName = packageName + if packageDir != "" { + component.TypeContext.PackageDir = packageDir + } + } + if strings.TrimSpace(component.TypeContext.PackagePath) == "" { + component.TypeContext.PackagePath = packagePath + } + if strings.TrimSpace(component.TypeContext.PackageName) == "" { + component.TypeContext.PackageName = packageName + } + if packageDir != "" && strings.TrimSpace(component.TypeContext.PackageDir) == "" { + component.TypeContext.PackageDir = packageDir + } + for _, group := range [][]*plan.State{component.Input, component.Output, component.Meta, component.Async, component.Other} { + for _, item := range group { + if item == nil { + continue + } + alignSchemaPackageAlias(item.Schema, packagePath, packageName) + alignSchemaPackageAlias(item.OutputSchema(), packagePath, packageName) + } + } + } + if resource == nil { + return + } + for _, item := range resource.Parameters { + if item == nil { + continue + } + alignSchemaPackageAlias(item.Schema, packagePath, packageName) + alignSchemaPackageAlias(item.OutputSchema(), packagePath, packageName) + } + for _, aView := range resource.Views { + if aView == nil { + continue + } + alignSchemaPackageAlias(aView.Schema, packagePath, packageName) + if aView.Template != nil { + alignSchemaPackageAlias(aView.Template.Schema, packagePath, packageName) + alignParameterPackages(aView.Template.Parameters, packagePath, packageName) + if aView.Template.Summary != nil { + alignSchemaPackageAlias(aView.Template.Summary.Schema, packagePath, packageName) + } + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + alignSchemaPackageAlias(rel.Of.Schema, packagePath, packageName) + alignSchemaPackageAlias(rel.Of.View.Schema, packagePath, packageName) + if rel.Of.View.Template != nil && rel.Of.View.Template.Summary != nil { + alignSchemaPackageAlias(rel.Of.View.Template.Summary.Schema, packagePath, packageName) + } + } + } + for _, item := range resource.Types { + if item == nil { + continue + } + if firstNonEmpty(strings.TrimSpace(item.ModulePath), schemaPackagePath(item.Schema)) == packagePath { + item.Package = packageName + } + alignSchemaPackageAlias(item.Schema, packagePath, packageName) + for _, field := range item.Fields { + if field == nil { + continue + } + alignSchemaPackageAlias(field.Schema, packagePath, packageName) + } + } +} + +func alignParameterPackages(params state.Parameters, packagePath, packageName string) { + for _, item := range params { + if item == nil { + continue + } + alignSchemaPackageAlias(item.Schema, packagePath, packageName) + alignSchemaPackageAlias(item.OutputSchema(), packagePath, packageName) + } +} + +func alignSchemaPackageAlias(schema *state.Schema, packagePath, packageName string) { + if schema == nil { + return + } + if schemaPackagePath(schema) == packagePath { + schema.Package = packageName + qualifyGeneratedSchemaDataType(schema, packageName) + } +} + +func schemaPackagePath(schema *state.Schema) string { + if schema == nil { + return "" + } + return firstNonEmpty(strings.TrimSpace(schema.PackagePath), strings.TrimSpace(schema.ModulePath)) +} + +func qualifyGeneratedSchemaDataType(schema *state.Schema, packageName string) { + if schema == nil { + return + } + packageName = strings.TrimSpace(packageName) + if packageName == "" { + return + } + dataType := strings.TrimSpace(schema.DataType) + typeName := strings.TrimLeft(strings.TrimSpace(schema.Name), "*") + if dataType == "" || typeName == "" || strings.Contains(dataType, ".") { + return + } + replacements := map[string]string{ + typeName: packageName + "." + typeName, + "*" + typeName: "*" + packageName + "." + typeName, + "[]" + typeName: "[]" + packageName + "." + typeName, + "[]*" + typeName: "[]*" + packageName + "." + typeName, + } + if qualified, ok := replacements[dataType]; ok { + schema.DataType = qualified + } +} + +func ensureTypeNameTag(tag string, typeName string) string { + typeName = strings.TrimSpace(typeName) + if typeName == "" { + return strings.TrimSpace(tag) + } + tag = strings.TrimSpace(tag) + if strings.Contains(tag, `typeName:"`) { + return tag + } + if tag == "" { + return fmt.Sprintf(`typeName:"%s"`, typeName) + } + return tag + ` typeName:"` + typeName + `"` +} + +func normalizeConnectorRefsYAML(data []byte) ([]byte, error) { + var node yaml.Node + if err := yaml.Unmarshal(data, &node); err != nil { + return nil, err + } + rewriteConnectorNode(&node) + return yaml.Marshal(&node) +} + +func normalizeRouteViewRefsYAML(data []byte) ([]byte, error) { + var node yaml.Node + if err := yaml.Unmarshal(data, &node); err != nil { + return nil, err + } + rewriteRouteViewNode(&node) + return yaml.Marshal(&node) +} + +func normalizeRouteComponentEmbeddingYAML(data []byte) ([]byte, error) { + var node yaml.Node + if err := yaml.Unmarshal(data, &node); err != nil { + return nil, err + } + if len(node.Content) == 0 || node.Content[0] == nil { + return data, nil + } + root := node.Content[0] + routes := yamlMapLookup(root, "Routes") + if routes == nil || routes.Kind != yaml.SequenceNode { + return data, nil + } + for _, item := range routes.Content { + flattenRouteComponentNode(item) + } + return yaml.Marshal(&node) +} + +func ensureSharedResourceRefsYAML(data []byte, refs []string) ([]byte, error) { + if len(refs) == 0 { + return data, nil + } + var node yaml.Node + if err := yaml.Unmarshal(data, &node); err != nil { + return nil, err + } + if len(node.Content) == 0 || node.Content[0] == nil || node.Content[0].Kind != yaml.MappingNode { + return data, nil + } + root := node.Content[0] + if existing := yamlMapLookup(root, "With"); existing != nil && existing.Kind == yaml.SequenceNode && len(existing.Content) > 0 { + return data, nil + } + seq := &yaml.Node{Kind: yaml.SequenceNode} + for _, ref := range refs { + if strings.TrimSpace(ref) == "" { + continue + } + seq.Content = append(seq.Content, &yaml.Node{Kind: yaml.ScalarNode, Value: strings.TrimSpace(ref), Tag: "!!str"}) + } + if len(seq.Content) == 0 { + return data, nil + } + root.Content = append(root.Content, + &yaml.Node{Kind: yaml.ScalarNode, Value: "With", Tag: "!!str"}, + seq, + ) + return yaml.Marshal(&node) +} + +func rewriteConnectorNode(node *yaml.Node) { + if node == nil { + return + } + switch node.Kind { + case yaml.DocumentNode: + for _, child := range node.Content { + rewriteConnectorNode(child) + } + case yaml.MappingNode: + for i := 0; i < len(node.Content)-1; i += 2 { + key := node.Content[i] + value := node.Content[i+1] + switch strings.ToLower(strings.TrimSpace(key.Value)) { + case "connector": + if ref := connectorRefFromYAMLNode(value); ref != "" { + node.Content[i+1] = connectorRefYAMLNode(ref) + value = node.Content[i+1] + } + case "connectors": + if value.Kind == yaml.SequenceNode { + for j, item := range value.Content { + if ref := connectorRefFromYAMLNode(item); ref != "" { + value.Content[j] = connectorRefYAMLNode(ref) + } + } + } + } + rewriteConnectorNode(value) + } + case yaml.SequenceNode: + for _, child := range node.Content { + rewriteConnectorNode(child) + } + } +} + +func rewriteRouteViewNode(node *yaml.Node) { + if node == nil { + return + } + switch node.Kind { + case yaml.DocumentNode: + for _, child := range node.Content { + rewriteRouteViewNode(child) + } + case yaml.MappingNode: + for i := 0; i < len(node.Content)-1; i += 2 { + key := node.Content[i] + value := node.Content[i+1] + if strings.EqualFold(strings.TrimSpace(key.Value), "view") && value.Kind == yaml.MappingNode { + if ref := routeViewRefFromYAMLNode(value); ref != "" { + node.Content[i+1] = &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + {Kind: yaml.ScalarNode, Value: "Ref", Tag: "!!str"}, + {Kind: yaml.ScalarNode, Value: ref, Tag: "!!str"}, + }, + } + value = node.Content[i+1] + } + } + rewriteRouteViewNode(value) + } + case yaml.SequenceNode: + for _, child := range node.Content { + rewriteRouteViewNode(child) + } + } +} + +func flattenRouteComponentNode(node *yaml.Node) { + if node == nil || node.Kind != yaml.MappingNode { + return + } + for _, key := range []string{"meta", "path", "contract"} { + embedded := yamlMapLookup(node, key) + if embedded == nil || embedded.Kind != yaml.MappingNode { + continue + } + removeYAMLMapKey(node, key) + node.Content = append(node.Content, embedded.Content...) + } +} + +func routeViewRefFromYAMLNode(node *yaml.Node) string { + if node == nil || node.Kind != yaml.MappingNode { + return "" + } + if ref := yamlMapLookup(node, "Ref"); ref != nil && strings.TrimSpace(ref.Value) != "" { + return strings.TrimSpace(ref.Value) + } + reference := yamlMapLookup(node, "reference") + if reference == nil { + return "" + } + ref := yamlMapLookup(reference, "ref") + if ref == nil { + return "" + } + return strings.TrimSpace(ref.Value) +} + +func connectorRefFromYAMLNode(node *yaml.Node) string { + if node == nil || node.Kind != yaml.MappingNode { + return "" + } + if ref := yamlMapLookup(node, "ref"); ref != nil && strings.TrimSpace(ref.Value) != "" { + return strings.TrimSpace(ref.Value) + } + connection := yamlMapLookup(node, "connection") + if connection == nil { + return "" + } + dbConfig := yamlMapLookup(connection, "dbconfig") + if dbConfig == nil { + return "" + } + reference := yamlMapLookup(dbConfig, "reference") + if reference == nil { + return "" + } + ref := yamlMapLookup(reference, "ref") + if ref == nil { + return "" + } + return strings.TrimSpace(ref.Value) +} + +func connectorRefYAMLNode(ref string) *yaml.Node { + return &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + {Kind: yaml.ScalarNode, Value: "ref", Tag: "!!str"}, + {Kind: yaml.ScalarNode, Value: ref, Tag: "!!str"}, + }, + } +} + +func yamlMapLookup(node *yaml.Node, key string) *yaml.Node { + if node == nil || node.Kind != yaml.MappingNode { + return nil + } + for i := 0; i < len(node.Content)-1; i += 2 { + if strings.EqualFold(strings.TrimSpace(node.Content[i].Value), key) { + return node.Content[i+1] + } + } + return nil +} + +func removeYAMLMapKey(node *yaml.Node, key string) { + if node == nil || node.Kind != yaml.MappingNode { + return + } + filtered := make([]*yaml.Node, 0, len(node.Content)) + for i := 0; i < len(node.Content)-1; i += 2 { + if strings.EqualFold(strings.TrimSpace(node.Content[i].Value), key) { + continue + } + filtered = append(filtered, node.Content[i], node.Content[i+1]) + } + node.Content = filtered +} diff --git a/cmd/command/transcribe_test.go b/cmd/command/transcribe_test.go new file mode 100644 index 000000000..73d8b5d4a --- /dev/null +++ b/cmd/command/transcribe_test.go @@ -0,0 +1,507 @@ +package command + +import ( + "context" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" + extension "github.com/viant/datly/view/extension" + "github.com/viant/datly/view/state" + "gopkg.in/yaml.v3" +) + +func TestPatchBasicOne_LoadedComponentHasMutableExecHelpers(t *testing.T) { + source := filepath.Join("..", "..", "e2e", "v1", "dql", "dev", "events", "patch_basic_one.dql") + data, err := os.ReadFile(source) + require.NoError(t, err) + + planned, err := shapeCompile.New().Compile(context.Background(), &shape.Source{ + Name: "patch_basic_one", + Path: source, + DQL: string(data), + }) + require.NoError(t, err) + + artifact, err := shapeLoad.New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component) + + root := lookupNamedView(artifact.Resource, component.RootView) + require.NotNil(t, root) + assert.Equal(t, view.ModeExec, root.Mode) + require.NotNil(t, root.Template) + assert.True(t, root.Template.UseParameterStateType) + require.NotNil(t, root.Template.Parameters.Lookup("CurFoosId")) + require.NotNil(t, root.Template.Parameters.Lookup("CurFoos")) + assert.Equal(t, state.Many, root.Template.Parameters.Lookup("CurFoos").Schema.Cardinality) + + input := component.InputParameters() + require.Nil(t, input.Lookup("CurFoosId")) + require.Nil(t, input.Lookup("CurFoos")) + + curFoos, err := artifact.Resource.View("CurFoos") + require.NoError(t, err) + require.NotNil(t, curFoos) + require.NotNil(t, curFoos.Template) + assert.Equal(t, "foos/cur_foos.sql", curFoos.Template.SourceURL) +} + +func TestPatchBasicOne_LoadedComponentHasMutableExecHelpers_WithTypeContextPackages(t *testing.T) { + source := filepath.Join("..", "..", "e2e", "v1", "dql", "dev", "events", "patch_basic_one.dql") + data, err := os.ReadFile(source) + require.NoError(t, err) + + planned, err := shapeCompile.New().Compile(context.Background(), &shape.Source{ + Name: "patch_basic_one", + Path: source, + DQL: string(data), + }, transcribeCompileOptions(&options.Transcribe{ + Project: filepath.Join("..", "..", "e2e", "v1"), + Module: filepath.Join("..", "..", "e2e", "v1"), + TypeOutput: filepath.Join("..", "..", "e2e", "v1", "shape"), + Namespace: "dev/basic/foos", + })...) + require.NoError(t, err) + + artifact, err := shapeLoad.New().LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component) + + root := lookupNamedView(artifact.Resource, component.RootView) + require.NotNil(t, root) + require.NotNil(t, root.Template) + require.NotNil(t, root.Template.Parameters.Lookup("CurFoos")) + assert.Equal(t, state.Many, root.Template.Parameters.Lookup("CurFoos").Schema.Cardinality) + + curFoos, err := artifact.Resource.View("CurFoos") + require.NoError(t, err) + require.NotNil(t, curFoos) + require.NotNil(t, curFoos.Template) + require.True(t, curFoos.Template.DeclaredParametersOnly) + require.NotNil(t, curFoos.Template.Parameters.Lookup("CurFoosId")) + require.Nil(t, curFoos.Template.Parameters.Lookup("Foos")) +} + +func TestTranscribeSharedResourceRefs_IncludesConnectors(t *testing.T) { + resource := &view.Resource{ + Connectors: []*view.Connector{view.NewRefConnector("dev")}, + } + + refs := transcribeSharedResourceRefs(resource) + require.Equal(t, []string{view.ResourceConnectors}, refs) +} + +func TestEnsureSharedResourceRefsYAML_AppendsWith(t *testing.T) { + data, err := ensureSharedResourceRefsYAML([]byte("Resource: {}\nRoutes: []\n"), []string{view.ResourceConnectors}) + require.NoError(t, err) + assert.True(t, strings.Contains(string(data), "With:\n - connectors")) +} + +func TestNormalizeParameterTypeNameTags_AppendsTypeName(t *testing.T) { + params := state.Parameters{ + &state.Parameter{ + Name: "Foos", + Tag: `anonymous:"true"`, + Schema: &state.Schema{Name: "FoosView"}, + }, + } + + normalized := normalizeParameterTypeNameTags(params) + require.Len(t, normalized, 1) + assert.Equal(t, `anonymous:"true" typeName:"FoosView"`, normalized[0].Tag) +} + +func TestPreserveTemplateParameters_AppendsMutableHelpers(t *testing.T) { + aView := &view.View{ + Name: "foos", + Template: view.NewTemplate("SELECT 1", view.WithTemplateParameters( + &state.Parameter{Name: "Foos", In: state.NewBodyLocation(""), Schema: &state.Schema{Name: "FoosView"}}, + )), + } + + params := state.Parameters{ + &state.Parameter{Name: "Foos", In: state.NewBodyLocation(""), Schema: &state.Schema{Name: "FoosView"}}, + &state.Parameter{Name: "CurFoosId", In: state.NewParameterLocation("Foos"), Schema: &state.Schema{DataType: "int"}}, + &state.Parameter{Name: "CurFoos", In: state.NewViewLocation("CurFoos"), Schema: &state.Schema{Name: "FoosView"}}, + &state.Parameter{Name: "Meta", In: state.NewOutputLocation("summary"), Schema: &state.Schema{Name: "MetaView"}}, + } + + preserveTemplateParameters(aView, params) + + require.NotNil(t, aView.Template.Parameters.Lookup("Foos")) + require.NotNil(t, aView.Template.Parameters.Lookup("CurFoosId")) + require.NotNil(t, aView.Template.Parameters.Lookup("CurFoos")) + require.Nil(t, aView.Template.Parameters.Lookup("Meta")) +} + +func TestPreserveTemplateParameters_SkipsDeclaredOnlyTemplate(t *testing.T) { + aView := &view.View{ + Name: "CurFoos", + Template: view.NewTemplate("SELECT 1", + view.WithTemplateParameters( + &state.Parameter{Name: "CurFoosId", In: state.NewParameterLocation("Foos"), Schema: &state.Schema{DataType: "int"}}, + ), + view.WithTemplateDeclaredParametersOnly(true), + ), + } + + params := state.Parameters{ + &state.Parameter{Name: "Foos", In: state.NewBodyLocation(""), Schema: &state.Schema{Name: "FoosView"}}, + } + + preserveTemplateParameters(aView, params) + + require.NotNil(t, aView.Template.Parameters.Lookup("CurFoosId")) + require.Nil(t, aView.Template.Parameters.Lookup("Foos")) +} + +func TestPrepareResourceForTranscribeCodegen_DeclaredOnlyTemplateKeepsOnlyUsedParams(t *testing.T) { + resource := &view.Resource{ + Parameters: state.Parameters{ + &state.Parameter{Name: "Foos", In: state.NewBodyLocation(""), Schema: &state.Schema{Name: "FoosView"}}, + &state.Parameter{Name: "CurFoosId", In: state.NewParameterLocation("Foos"), Schema: &state.Schema{DataType: "int"}}, + }, + Views: []*view.View{ + { + Name: "foos", + Template: view.NewTemplate("SELECT 1", view.WithTemplateParameters( + &state.Parameter{Name: "Foos", In: state.NewBodyLocation(""), Schema: &state.Schema{Name: "FoosView"}}, + )), + }, + { + Name: "CurFoos", + Template: view.NewTemplate(`SELECT * FROM FOOS WHERE $criteria.In("ID", $CurFoosId.Values)`, + view.WithTemplateParameters( + &state.Parameter{Name: "CurFoosId", In: state.NewParameterLocation("Foos"), Schema: &state.Schema{DataType: "int"}}, + ), + view.WithTemplateDeclaredParametersOnly(true), + ), + }, + }, + } + component := &shapeLoad.Component{RootView: "foos"} + + prepareResourceForTranscribeCodegen(resource, component) + + curFoos := lookupNamedView(resource, "CurFoos") + require.NotNil(t, curFoos) + require.NotNil(t, curFoos.Template) + require.NotNil(t, curFoos.Template.Parameters.Lookup("CurFoosId")) + require.Nil(t, curFoos.Template.Parameters.Lookup("Foos")) +} + +func TestDependentTemplateParameters_AppendsParentSourceParameter(t *testing.T) { + params := state.Parameters{ + &state.Parameter{Name: "CurFoosId", In: state.NewParameterLocation("Foos"), Schema: &state.Schema{DataType: "int"}}, + } + resourceParams := state.Parameters{ + &state.Parameter{Name: "Foos", In: state.NewBodyLocation(""), Schema: &state.Schema{Name: "FoosView"}}, + &state.Parameter{Name: "CurFoosId", In: state.NewParameterLocation("Foos"), Schema: &state.Schema{DataType: "int"}}, + } + + deps := dependentTemplateParameters(params, resourceParams) + + require.Len(t, deps, 1) + require.Equal(t, "Foos", deps[0].Name) +} + +func TestAlignGeneratedPackageAliases_UsesGeneratedPackageName(t *testing.T) { + const pkgPath = "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one" + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "foos", + Schema: &state.Schema{ + Package: "foos", + PackagePath: pkgPath, + Name: "FoosView", + DataType: "*FoosView", + Cardinality: state.Many, + }, + Template: view.NewTemplate("SELECT 1", view.WithTemplateParameters( + &state.Parameter{ + Name: "Foos", + Schema: &state.Schema{Package: "foos", PackagePath: pkgPath, Name: "FoosView", DataType: "*FoosView", Cardinality: state.One}, + }, + )), + }, + }, + Types: []*view.TypeDefinition{ + { + Name: "FoosView", + Package: "foos", + ModulePath: pkgPath, + Schema: &state.Schema{Package: "foos", PackagePath: pkgPath, Name: "FoosView", DataType: "*FoosView", Cardinality: state.Many}, + }, + }, + Parameters: state.Parameters{ + &state.Parameter{ + Name: "Foos", + Schema: &state.Schema{Package: "foos", PackagePath: pkgPath, Name: "FoosView", DataType: "*FoosView", Cardinality: state.One}, + }, + }, + } + component := &shapeLoad.Component{ + TypeContext: &typectx.Context{PackageName: "foos", PackagePath: pkgPath}, + Input: []*plan.State{ + {Parameter: state.Parameter{Name: "Foos", Schema: &state.Schema{Package: "foos", PackagePath: pkgPath, Name: "FoosView", DataType: "*FoosView", Cardinality: state.One}}}, + }, + } + + alignGeneratedPackageAliases(resource, component, filepath.Join("..", "..", "e2e", "v1", "shape", "dev", "events", "patch_basic_one"), pkgPath, "patch_basic_one") + + require.Equal(t, "patch_basic_one", component.TypeContext.PackageName) + require.Equal(t, filepath.ToSlash(filepath.Clean(filepath.Join("..", "..", "e2e", "v1", "shape", "dev", "events", "patch_basic_one"))), component.TypeContext.PackageDir) + require.Equal(t, "patch_basic_one", component.Input[0].Schema.Package) + require.Equal(t, "*patch_basic_one.FoosView", component.Input[0].Schema.DataType) + require.Equal(t, "patch_basic_one", resource.Views[0].Schema.Package) + require.Equal(t, "*patch_basic_one.FoosView", resource.Views[0].Schema.DataType) + require.Equal(t, "patch_basic_one", resource.Views[0].Template.Parameters[0].Schema.Package) + require.Equal(t, "*patch_basic_one.FoosView", resource.Views[0].Template.Parameters[0].Schema.DataType) + require.Equal(t, "patch_basic_one", resource.Parameters[0].Schema.Package) + require.Equal(t, "*patch_basic_one.FoosView", resource.Parameters[0].Schema.DataType) + require.Equal(t, "patch_basic_one", resource.Types[0].Package) + require.Equal(t, "patch_basic_one", resource.Types[0].Schema.Package) + require.Equal(t, "*patch_basic_one.FoosView", resource.Types[0].Schema.DataType) +} + +func TestGenerateTranscribeTypes_RealignsGeneratedPackageAlias(t *testing.T) { + source := filepath.Join("..", "..", "e2e", "v1", "dql", "dev", "events", "patch_basic_one.dql") + data, err := os.ReadFile(source) + require.NoError(t, err) + + planned, err := shapeCompile.New().Compile(context.Background(), &shape.Source{ + Name: "patch_basic_one", + Path: source, + DQL: string(data), + Connector: "dev", + }, transcribeCompileOptions(&options.Transcribe{ + Project: filepath.Join("..", "..", "e2e", "v1"), + Module: filepath.Join("..", "..", "e2e", "v1"), + TypeOutput: filepath.Join("..", "..", "e2e", "v1", "shape"), + Namespace: "dev/basic/foos", + })...) + require.NoError(t, err) + + artifact, err := shapeLoad.New().LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + + svc := &Service{} + result, err := svc.generateTranscribeTypes(source, string(data), &options.Transcribe{ + Project: filepath.Join("..", "..", "e2e", "v1"), + Module: filepath.Join("..", "..", "e2e", "v1"), + TypeOutput: filepath.Join("..", "..", "e2e", "v1", "shape"), + Namespace: "dev/basic/foos", + }, artifact.Resource, component) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "patch_basic_one", result.PackageName) + + alignGeneratedPackageAliases(artifact.Resource, component, result.PackageDir, result.PackagePath, result.PackageName) + + root := lookupNamedView(artifact.Resource, component.RootView) + require.NotNil(t, root) + require.Equal(t, result.PackageName, root.Schema.Package) + require.Equal(t, "*patch_basic_one.FoosView", root.Schema.DataType) + require.Equal(t, result.PackageName, artifact.Resource.Parameters.Lookup("Foos").Schema.Package) + require.Equal(t, "*patch_basic_one.FoosView", artifact.Resource.Parameters.Lookup("Foos").Schema.DataType) +} + +func TestTranscribe_PatchBasicOneRouteYAMLUsesGeneratedPackageName(t *testing.T) { + cwd, err := os.Getwd() + require.NoError(t, err) + repoRoot := filepath.Clean(filepath.Join(cwd, "..", "..")) + project := filepath.Join(repoRoot, "e2e", "v1") + tempRepo := t.TempDir() + + svc := New() + err = svc.Transcribe(context.Background(), &options.Options{ + Transcribe: &options.Transcribe{ + Source: []string{filepath.Join(project, "dql", "dev", "events", "patch_basic_one.dql")}, + Repository: tempRepo, + Project: project, + Module: project, + TypeOutput: filepath.Join(project, "shape"), + Namespace: "dev/basic/foos", + APIPrefix: "/v1/api/shape", + }, + }) + require.NoError(t, err) + + routeYAML := filepath.Join(tempRepo, "Datly", "routes", "patch_basic_one.yaml") + data, err := os.ReadFile(routeYAML) + require.NoError(t, err) + text := string(data) + require.Contains(t, text, "Package: patch_basic_one") + require.Contains(t, text, "DataType: '*patch_basic_one.FoosView'") + require.Contains(t, text, "CaseFormat: lc") + require.NotContains(t, text, "Package: foos\n") +} + +func TestTranscribe_PatchBasicOneRouteYAMLBuildsNamedTemplateState(t *testing.T) { + cwd, err := os.Getwd() + require.NoError(t, err) + repoRoot := filepath.Clean(filepath.Join(cwd, "..", "..")) + project := filepath.Join(repoRoot, "e2e", "v1") + tempRepo := t.TempDir() + + svc := New() + err = svc.Transcribe(context.Background(), &options.Options{ + Transcribe: &options.Transcribe{ + Source: []string{filepath.Join(project, "dql", "dev", "events", "patch_basic_one.dql")}, + Repository: tempRepo, + Project: project, + Module: project, + TypeOutput: filepath.Join(project, "shape"), + Namespace: "dev/basic/foos", + APIPrefix: "/v1/api/shape", + }, + }) + require.NoError(t, err) + + routeYAML := filepath.Join(tempRepo, "Datly", "routes", "patch_basic_one.yaml") + data, err := os.ReadFile(routeYAML) + require.NoError(t, err) + + payload := &shapeRuleFile{} + require.NoError(t, yaml.Unmarshal(data, payload)) + require.NotNil(t, payload.Resource) + payload.Resource.Connectors = []*view.Connector{view.NewConnector("dev", "sqlite3", "file::memory:?cache=shared")} + payload.Resource.SetTypes(extension.Config.Types) + require.NoError(t, payload.Resource.Init(context.Background(), payload.Resource.TypeRegistry(), extension.Config.Codecs, nil, nil, extension.Config.Predicates)) + + root := lookupNamedView(payload.Resource, "foos") + require.NotNil(t, root) + require.NotNil(t, root.Template) + require.NotNil(t, root.Template.StateType()) + + rType := root.Template.StateType().Type() + field, ok := rType.FieldByName("Foos") + require.True(t, ok) + require.Equal(t, "*patch_basic_one.FoosView", field.Type.String()) +} + +func TestTranscribe_PatchBasicOneRouteYAMLPreservesNamedHelperParamTypes(t *testing.T) { + cwd, err := os.Getwd() + require.NoError(t, err) + repoRoot := filepath.Clean(filepath.Join(cwd, "..", "..")) + project := filepath.Join(repoRoot, "e2e", "v1") + tempRepo := t.TempDir() + + svc := New() + err = svc.Transcribe(context.Background(), &options.Options{ + Transcribe: &options.Transcribe{ + Source: []string{filepath.Join(project, "dql", "dev", "events", "patch_basic_one.dql")}, + Repository: tempRepo, + Project: project, + Module: project, + TypeOutput: filepath.Join(project, "shape"), + Namespace: "dev/basic/foos", + APIPrefix: "/v1/api/shape", + }, + }) + require.NoError(t, err) + + routeYAML := filepath.Join(tempRepo, "Datly", "routes", "patch_basic_one.yaml") + data, err := os.ReadFile(routeYAML) + require.NoError(t, err) + + payload := &shapeRuleFile{} + require.NoError(t, yaml.Unmarshal(data, payload)) + require.NotNil(t, payload.Resource) + + curFoosID := payload.Resource.Parameters.Lookup("CurFoosId") + require.NotNil(t, curFoosID) + require.NotNil(t, curFoosID.Schema) + require.Equal(t, "*patch_basic_one.FoosView", curFoosID.Schema.DataType) + + require.NotNil(t, curFoosID.Output) + require.NotNil(t, curFoosID.Output.Schema) + require.Equal(t, `*struct { Values []int "json:\",omitempty\"" }`, curFoosID.Output.Schema.DataType) + + curFoos := lookupNamedView(payload.Resource, "CurFoos") + require.NotNil(t, curFoos) + require.NotNil(t, curFoos.Template) + curFoosParam := curFoos.Template.Parameters.Lookup("CurFoosId") + require.NotNil(t, curFoosParam) + require.NotNil(t, curFoosParam.Schema) + require.Equal(t, "*patch_basic_one.FoosView", curFoosParam.Schema.DataType) +} + +func TestGenerateTranscribeTypes_MetaFormatPreservesChildSummaryType(t *testing.T) { + source := filepath.Join("..", "..", "e2e", "v1", "dql", "dev", "vendorsrv", "meta_format.dql") + data, err := os.ReadFile(source) + require.NoError(t, err) + + project := filepath.Join("..", "..", "e2e", "v1") + shapeOutput := filepath.Join(project, "shape") + transcribeOpts := &options.Transcribe{ + Project: project, + Module: project, + TypeOutput: shapeOutput, + Namespace: "dev/vendor/meta-format", + } + transcribeOpts.Connectors = []string{"dev|mysql|root:dev@tcp(localhost:3306)/dev?parseTime=true"} + + planned, err := shapeCompile.New().Compile(context.Background(), &shape.Source{ + Name: "meta_format", + Path: source, + DQL: string(data), + Connector: "dev", + }, transcribeCompileOptions(transcribeOpts)...) + require.NoError(t, err) + + artifact, err := shapeLoad.New().LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + + applyConnectorsToResource(artifact.Resource, transcribeOpts.Connectors) + discoverColumns(context.Background(), artifact.Resource) + prepareResourceForTranscribeCodegen(artifact.Resource, component) + + products := lookupNamedView(artifact.Resource, "products") + require.NotNil(t, products) + require.NotNil(t, products.Template) + require.NotNil(t, products.Template.Summary) + require.NotNil(t, products.Template.Summary.Schema) + summaryType := products.Template.Summary.Schema.Type() + require.NotNil(t, summaryType) + if summaryType.Kind() == reflect.Ptr { + summaryType = summaryType.Elem() + } + field, ok := summaryType.FieldByName("VendorId") + require.True(t, ok) + require.Equal(t, reflect.TypeOf((*int)(nil)), field.Type) + + svc := &Service{} + result, err := svc.generateTranscribeTypes(source, string(data), transcribeOpts, artifact.Resource, component) + require.NoError(t, err) + require.NotNil(t, result) + + outputSource, err := os.ReadFile(result.OutputFilePath) + require.NoError(t, err) + assert.Contains(t, string(outputSource), `type ProductsMetaView struct {`) + assert.Contains(t, string(outputSource), `VendorId *int`) + assert.NotContains(t, string(outputSource), `VendorId string`) +} diff --git a/cmd/command/translate_shape.go b/cmd/command/translate_shape.go index fa12d8f03..3d6ddc205 100644 --- a/cmd/command/translate_shape.go +++ b/cmd/command/translate_shape.go @@ -70,10 +70,11 @@ func (s *Service) translateShape(ctx context.Context, opts *options.Options) err type shapeRuleFile struct { Resource *view.Resource `yaml:"Resource,omitempty"` Routes []*repository.Component `yaml:"Routes,omitempty"` + With []string `yaml:"With,omitempty"` TypeContext any `yaml:"TypeContext,omitempty"` } -func (s *Service) persistShapeRoute(ctx context.Context, opts *options.Options, sourceURL, dql string, resource *view.Resource, component *shapeLoad.Component) error { +func (s *Service) persistShapeRoute(ctx context.Context, opts *options.Options, sourceURL, _ string, resource *view.Resource, component *shapeLoad.Component) error { rule := opts.Rule() routeYAML, routeRoot, relDir, stem, err := routePathForShape(rule, opts.Repository().RepositoryURL, sourceURL) if err != nil { @@ -105,18 +106,9 @@ func (s *Service) persistShapeRoute(ctx context.Context, opts *options.Options, if rootView == "" && resource != nil && len(resource.Views) > 0 && resource.Views[0] != nil { rootView = resource.Views[0].Name } - method, uri := parseShapeRulePath(dql, rule.RuleName(), opts.Repository().APIPrefix) - // Gap 3: RouteDirective overrides method/URI when explicitly declared in DQL. - if component != nil && component.Directives != nil && component.Directives.Route != nil { - rd := component.Directives.Route - if u := strings.TrimSpace(rd.URI); u != "" { - uri = u - } - if len(rd.Methods) > 0 { - if m := strings.TrimSpace(strings.ToUpper(rd.Methods[0])); m != "" { - method = m - } - } + method, uri, err := shapeComponentPath(component) + if err != nil { + return err } route := &repository.Component{ Path: contract.Path{ @@ -172,6 +164,27 @@ func (s *Service) persistShapeRoute(ctx context.Context, opts *options.Options, return nil } +func shapeComponentPath(component *shapeLoad.Component) (string, string, error) { + if component == nil { + return "", "", fmt.Errorf("shape component was nil") + } + method := strings.TrimSpace(strings.ToUpper(component.Method)) + uri := strings.TrimSpace(component.URI) + if method == "" && len(component.ComponentRoutes) > 0 && component.ComponentRoutes[0] != nil { + method = strings.TrimSpace(strings.ToUpper(component.ComponentRoutes[0].Method)) + } + if uri == "" && len(component.ComponentRoutes) > 0 && component.ComponentRoutes[0] != nil { + uri = strings.TrimSpace(component.ComponentRoutes[0].RoutePath) + } + if method == "" { + method = "GET" + } + if uri == "" { + return "", "", fmt.Errorf("shape component route URI was empty") + } + return method, uri, nil +} + func routePathForShape(rule *options.Rule, repoURL, sourceURL string) (routeYAML string, routeRoot string, relDir string, stem string, err error) { sourcePath := filepath.Clean(url.Path(sourceURL)) basePath := filepath.Clean(rule.BaseRuleURL()) diff --git a/cmd/command/translate_shape_ir.go b/cmd/command/translate_shape_ir.go index c2f37155f..8e51812bf 100644 --- a/cmd/command/translate_shape_ir.go +++ b/cmd/command/translate_shape_ir.go @@ -92,8 +92,7 @@ func (s *Service) translateShapeIR(ctx context.Context, opts *options.Options) e return nil } -func buildShapeRulePayload(opts *options.Options, dql string, resource *view.Resource, component *shapeLoad.Component) (*shapeRuleFile, error) { - rule := opts.Rule() +func buildShapeRulePayload(_ *options.Options, _ string, resource *view.Resource, component *shapeLoad.Component) (*shapeRuleFile, error) { rootView := "" if component != nil { rootView = strings.TrimSpace(component.RootView) @@ -101,18 +100,9 @@ func buildShapeRulePayload(opts *options.Options, dql string, resource *view.Res if rootView == "" && resource != nil && len(resource.Views) > 0 && resource.Views[0] != nil { rootView = resource.Views[0].Name } - method, uri := parseShapeRulePath(dql, rule.RuleName(), opts.Repository().APIPrefix) - // Gap 3: RouteDirective overrides method/URI when explicitly declared in DQL. - if component != nil && component.Directives != nil && component.Directives.Route != nil { - rd := component.Directives.Route - if u := strings.TrimSpace(rd.URI); u != "" { - uri = u - } - if len(rd.Methods) > 0 { - if m := strings.TrimSpace(strings.ToUpper(rd.Methods[0])); m != "" { - method = m - } - } + method, uri, err := shapeComponentPath(component) + if err != nil { + return nil, err } route := &repository.Component{ Path: contract.Path{ diff --git a/cmd/command/validate.go b/cmd/command/validate.go new file mode 100644 index 000000000..9ce7657cb --- /dev/null +++ b/cmd/command/validate.go @@ -0,0 +1,181 @@ +package command + +import ( + "context" + "fmt" + "path" + "path/filepath" + "sort" + "strings" + + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/plan" + shapevalidate "github.com/viant/datly/repository/shape/validate" +) + +func (s *Service) Validate(ctx context.Context, opts *options.Options) error { + validate := opts.Validate + if validate == nil { + return fmt.Errorf("validate options not set") + } + compiler := shapeCompile.New() + loader := shapeLoad.New() + var validated []string + for _, sourceURL := range validate.Source { + dql, err := s.readSource(ctx, sourceURL) + if err != nil { + return fmt.Errorf("failed to read %s: %w", sourceURL, err) + } + shapeSource := &shape.Source{ + Name: strings.TrimSuffix(filepath.Base(url.Path(sourceURL)), filepath.Ext(sourceURL)), + Path: url.Path(sourceURL), + DQL: strings.TrimSpace(dql), + Connector: validateDefaultConnectorName(validate), + } + planResult, err := compiler.Compile(ctx, shapeSource, validateCompileOptions(validate)...) + if err != nil { + return fmt.Errorf("validate %s: %w", sourceURL, err) + } + if err = validateDiagnostics(sourceURL, planResult); err != nil { + return err + } + if err = validatePlannedSQLAssets(ctx, s, shapeSource, planResult); err != nil { + return fmt.Errorf("validate %s: %w", sourceURL, err) + } + resourceArtifacts, err := loader.LoadResource(ctx, planResult, shape.WithLoadTypeContextPackages(true)) + if err != nil { + return fmt.Errorf("validate %s: %w", sourceURL, err) + } + if err = shapevalidate.ValidateRelations(resourceArtifacts.Resource); err != nil { + return fmt.Errorf("validate %s: %w", sourceURL, err) + } + validated = append(validated, filepath.Clean(url.Path(sourceURL))) + } + sort.Strings(validated) + for _, item := range validated { + fmt.Printf("validated %s\n", item) + } + return nil +} + +func validateDefaultConnectorName(v *options.Validate) string { + if v == nil || len(v.Connectors) == 0 { + return "" + } + parts := strings.SplitN(v.Connectors[0], "|", 2) + if len(parts) == 0 { + return "" + } + return strings.TrimSpace(parts[0]) +} + +func validateCompileOptions(v *options.Validate) []shape.CompileOption { + var opts []shape.CompileOption + if v != nil && v.Strict { + opts = append(opts, shape.WithCompileStrict(true)) + } + opts = append(opts, shape.WithLinkedTypes(false)) + return opts +} + +func validateDiagnostics(sourceURL string, result *shape.PlanResult) error { + planned, ok := plan.ResultFrom(result) + if !ok || planned == nil { + return nil + } + var issues []string + for _, diag := range planned.Diagnostics { + if diag == nil || diag.Severity != "error" { + continue + } + issues = append(issues, diag.Error()) + } + if len(issues) == 0 { + return nil + } + return fmt.Errorf("validate %s: %s", sourceURL, strings.Join(issues, "; ")) +} + +func validatePlannedSQLAssets(ctx context.Context, s *Service, source *shape.Source, result *shape.PlanResult) error { + planned, ok := plan.ResultFrom(result) + if !ok || planned == nil { + return nil + } + assets := collectPlannedSQLAssets(source, planned) + for _, asset := range assets { + if _, err := s.fs.DownloadWithURL(ctx, asset); err != nil { + return fmt.Errorf("missing SQL asset %s: %w", asset, err) + } + } + return nil +} + +func collectPlannedSQLAssets(source *shape.Source, planned *plan.Result) []string { + seen := map[string]bool{} + var result []string + appendAsset := func(candidate string) { + raw := strings.TrimSpace(candidate) + if !isExplicitSourceAsset(source, raw) { + return + } + candidate = raw + if candidate == "" { + return + } + if strings.Contains(candidate, "://") { + candidate = url.Path(candidate) + } + if !filepath.IsAbs(candidate) { + baseDir := "" + if source != nil { + baseDir = source.BaseDir() + } + if baseDir != "" { + candidate = filepath.Join(baseDir, filepath.FromSlash(candidate)) + } + } + candidate = filepath.Clean(candidate) + if candidate == "." || seen[candidate] { + return + } + seen[candidate] = true + result = append(result, file.Scheme+"://"+filepath.ToSlash(candidate)) + } + for _, route := range planned.Components { + if route == nil { + continue + } + appendAsset(route.SourceURL) + appendAsset(route.SummaryURL) + } + for _, item := range planned.Views { + if item == nil { + continue + } + appendAsset(item.SQLURI) + appendAsset(item.SummaryURL) + } + sort.Strings(result) + return result +} + +func isExplicitSourceAsset(source *shape.Source, candidate string) bool { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return false + } + if source == nil || strings.TrimSpace(source.DQL) == "" { + return true + } + clean := filepath.ToSlash(candidate) + if strings.Contains(source.DQL, clean) { + return true + } + base := path.Base(clean) + return base != "" && strings.Contains(source.DQL, base) +} diff --git a/cmd/command/validate_test.go b/cmd/command/validate_test.go new file mode 100644 index 000000000..7f7bb020a --- /dev/null +++ b/cmd/command/validate_test.go @@ -0,0 +1,65 @@ +package command + +import ( + "context" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func TestCollectPlannedSQLAssets_AbsolutizesAndDedupes(t *testing.T) { + tempDir := t.TempDir() + source := &shape.Source{Path: filepath.Join(tempDir, "query.dql")} + planned := &plan.Result{ + Components: []*plan.ComponentRoute{ + {SourceURL: "foo/root.sql"}, + }, + Views: []*plan.View{ + {Name: "foo", SQLURI: "foo/root.sql"}, + {Name: "bar", SQLURI: "bar/detail.sql", SummaryURL: "bar/summary.sql"}, + }, + } + + assets := collectPlannedSQLAssets(source, planned) + + require.Len(t, assets, 3) + require.Contains(t, assets[0]+assets[1]+assets[2], filepath.ToSlash(filepath.Join(tempDir, "foo", "root.sql"))) + require.Contains(t, assets[0]+assets[1]+assets[2], filepath.ToSlash(filepath.Join(tempDir, "bar", "detail.sql"))) + require.Contains(t, assets[0]+assets[1]+assets[2], filepath.ToSlash(filepath.Join(tempDir, "bar", "summary.sql"))) +} + +func TestValidatePlannedSQLAssets_MissingPath(t *testing.T) { + tempDir := t.TempDir() + source := &shape.Source{Path: filepath.Join(tempDir, "query.dql")} + planned := &shape.PlanResult{Source: source, Plan: &plan.Result{ + Views: []*plan.View{ + {Name: "foo", SQLURI: "foo/missing.sql"}, + }, + }} + + svc := New() + err := validatePlannedSQLAssets(context.Background(), svc, source, planned) + require.Error(t, err) + require.Contains(t, err.Error(), "missing SQL asset") + require.Contains(t, err.Error(), "missing.sql") +} + +func TestValidate_PatchBasicOne(t *testing.T) { + projectDir, err := filepath.Abs(filepath.Join("..", "..", "e2e", "v1")) + require.NoError(t, err) + source := filepath.Join(projectDir, "dql", "dev", "events", "patch_basic_one.dql") + svc := New() + + err = svc.Validate(context.Background(), &options.Options{ + Validate: &options.Validate{ + Project: projectDir, + Source: []string{source}, + }, + }) + + require.NoError(t, err) +} diff --git a/cmd/option.go b/cmd/option.go index 94511419a..85072d8d8 100644 --- a/cmd/option.go +++ b/cmd/option.go @@ -41,6 +41,10 @@ type ( cache *view.Cache SubstituesURL []string `long:"substituesURL" description:"substitues URL, expands template before processing"` JobURL string `short:"z" long:"joburl" description:"job url"` + MCPPort int `long:"mcpPort" description:"enable MCP HTTP server on the specified port"` + MCPAuthURL string `long:"mcpAuthClient" description:"auth client url for MCP server"` + MCPIssuerURL string `long:"mcpIssuerURL" description:"issuer url for MCP server"` + MCPAuthMode string `long:"mcpAuth" description:"authorizer S - server authorizer, F fallback authorizer"` } Package struct { @@ -166,7 +170,10 @@ func (o *Options) BuildOption() *options.Options { } if o.ConfigURL != "" && repo == nil { - result.Run = &options.Run{ConfigURL: o.ConfigURL, JobURL: o.JobURL} + result.Run = &options.Run{ConfigURL: o.ConfigURL, JobURL: o.JobURL, MCPAuthURL: o.MCPAuthURL, MCPIssuerURL: o.MCPIssuerURL, MCPAuthMode: o.MCPAuthMode} + if o.MCPPort > 0 { + result.Run.MCPPort = &o.MCPPort + } } return result } diff --git a/cmd/options/options.go b/cmd/options/options.go index f8bd60d45..4790eb1e4 100644 --- a/cmd/options/options.go +++ b/cmd/options/options.go @@ -10,6 +10,8 @@ type Options struct { Plugin *Plugin `command:"plugin" description:"build custom datly rule plugin" ` Generate *Generate `command:"gen" description:"generate dql for put,patch or post operation" ` Translate *Translate `command:"translate" description:"translate dql into datly repository rule"` + Transcribe *Transcribe `command:"transcribe" description:"compile dql with shape pipeline and generate bootstrap artifacts"` + Validate *Validate `command:"validate" description:"validate DQL and referenced SQL assets with the shape pipeline"` Cache *CacheWarmup `command:"cache" description:"warmup cache"` Run *Run `command:"run" description:"start datly in standalone mode"` Mcp *Mcp `command:"mcp" description:"run mcp"` @@ -72,6 +74,12 @@ func (o *Options) Init(ctx context.Context) error { if o.Translate != nil { return o.Translate.Init(ctx) } + if o.Transcribe != nil { + return o.Transcribe.Init(ctx) + } + if o.Validate != nil { + return o.Validate.Init(ctx) + } if o.Run != nil { return o.Run.Init() } @@ -105,6 +113,10 @@ func NewOptions(args Arguments) *Options { ret.InitCmd = &Init{} case "dsql", "translate", "dql": ret.Translate = &Translate{} + case "transcribe": + ret.Transcribe = &Transcribe{} + case "validate": + ret.Validate = &Validate{} case "cache": ret.Cache = &CacheWarmup{} case "run": diff --git a/cmd/options/run.go b/cmd/options/run.go index fb810b3f7..19c1f991f 100644 --- a/cmd/options/run.go +++ b/cmd/options/run.go @@ -11,6 +11,10 @@ type Run struct { MaxJobs int `short:"W" long:"mjobs" description:"max jobs" default:"40" ` FailedJobURL string `short:"F" long:"fjobs" description:"failed jobs" ` LoadPlugin bool `short:"L" long:"lplugin" description:"load plugin"` + MCPPort *int `long:"mcpPort" description:"enable MCP HTTP server on the specified port"` + MCPAuthURL string `long:"mcpAuthClient" description:"auth client url for MCP server"` + MCPIssuerURL string `long:"mcpIssuerURL" description:"issuer url for MCP server"` + MCPAuthMode string `long:"mcpAuth" description:"authorizer S - server authorizer, F fallback authorizer" choice:"F" choice:"S"` PluginInfo string Version string } diff --git a/cmd/options/transcribe.go b/cmd/options/transcribe.go new file mode 100644 index 000000000..8748c17e5 --- /dev/null +++ b/cmd/options/transcribe.go @@ -0,0 +1,66 @@ +package options + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/viant/afs/url" +) + +// Transcribe defines options for the shape-only DQL -> code/config pipeline. +type Transcribe struct { + Connector + Auth + Source []string `short:"s" long:"src" description:"DQL source file(s)"` + Repository string `short:"r" long:"repo" description:"output repository location" default:"repo/dev"` + Namespace string `short:"u" long:"namespace" description:"route namespace" default:"dev"` + Module string `short:"m" long:"module" description:"go module location" default:"."` + Strict bool `long:"strict" description:"enable strict compile mode"` + TypeOutput string `long:"type-output" description:"go type output directory (default: same as --module)"` + TypeFile string `long:"type-file" description:"generated go file name (default: dql filename or main view in lower_underscore)"` + Project string `short:"p" long:"proj" description:"project location"` + APIPrefix string `short:"a" long:"api" description:"api prefix" default:"/v1/api"` + SkipYAML bool `long:"skip-yaml" description:"generate bootstrap config and shapes without route yaml"` +} + +func (t *Transcribe) DefaultConnectorName() string { + if len(t.Connectors) == 0 { + return "" + } + parts := strings.SplitN(t.Connectors[0], "|", 2) + if len(parts) > 0 { + return strings.TrimSpace(parts[0]) + } + return "" +} + +func (t *Transcribe) Init(ctx context.Context) error { + _ = ctx + if t.Project == "" { + t.Project, _ = os.Getwd() + } + t.Project = ensureAbsPath(t.Project) + t.Connector.Init() + t.Auth.Init() + if url.IsRelative(t.Repository) { + t.Repository = url.Join(t.Project, t.Repository) + } + if url.IsRelative(t.Module) { + t.Module = url.Join(t.Project, t.Module) + } + if t.TypeOutput != "" && url.IsRelative(t.TypeOutput) { + t.TypeOutput = url.Join(t.Project, t.TypeOutput) + } + if len(t.Source) == 0 { + return fmt.Errorf("transcribe: at least one --src is required") + } + for i := range t.Source { + expandRelativeIfNeeded(&t.Source[i], t.Project) + } + if strings.TrimSpace(t.Namespace) == "" { + t.Namespace = "dev" + } + return nil +} diff --git a/cmd/options/validate.go b/cmd/options/validate.go new file mode 100644 index 000000000..6d98850ea --- /dev/null +++ b/cmd/options/validate.go @@ -0,0 +1,35 @@ +package options + +import ( + "context" + "fmt" + "os" + + "github.com/viant/afs/url" +) + +// Validate defines options for shape-only DQL validation. +type Validate struct { + Connector + Source []string `short:"s" long:"src" description:"DQL source file(s)"` + Project string `short:"p" long:"proj" description:"project location"` + Strict bool `long:"strict" description:"enable strict compile mode"` +} + +func (v *Validate) Init(ctx context.Context) error { + _ = ctx + if v.Project == "" { + v.Project, _ = os.Getwd() + } + v.Project = ensureAbsPath(v.Project) + v.Connector.Init() + if len(v.Source) == 0 { + return fmt.Errorf("validate: at least one --src is required") + } + for i := range v.Source { + if url.IsRelative(v.Source[i]) { + expandRelativeIfNeeded(&v.Source[i], v.Project) + } + } + return nil +} diff --git a/e2e/local/dql/generate_patch_basic_many/patch_basic_many.sql b/e2e/local/dql/generate_patch_basic_many/patch_basic_many.sql new file mode 100644 index 000000000..d99b50691 --- /dev/null +++ b/e2e/local/dql/generate_patch_basic_many/patch_basic_many.sql @@ -0,0 +1,33 @@ +/* {"URI":"/v1/api/dev/basic/foos-many","Method":"PATCH","Connector":"dev"} */ + + +import ( + "generate_patch_basic_many.Foos" + ) + + +#set($_ = $Foos<[]Foos>(body/).WithTag('anonymous:"true"').Required()) + #set($_ = $CurFoosId(param/Foos) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/` LIMIT 1 +*/ +) + #set($_ = $CurFoos<[]*Foos>(view/CurFoos) /* +? SELECT * FROM FOOS +WHERE $criteria.In("ID", $CurFoosId.Values) +*/ +) +#set($_ = $Foos<[]>(body/).WithTag('anonymous:"true" typeName:"Foos"').Required().Output()) + + + +$sequencer.Allocate("FOOS", $Foos, "Id") + +#set($CurFoosById = $CurFoos.IndexBy("Id")) + +#foreach($RecFoos in $Foos) + #if($CurFoosById.HasKey($RecFoos.Id) == true) +$sql.Update($RecFoos, "FOOS"); + #else +$sql.Insert($RecFoos, "FOOS"); + #end +#end \ No newline at end of file diff --git a/e2e/local/dql/generate_patch_basic_one/patch_basic_one.sql b/e2e/local/dql/generate_patch_basic_one/patch_basic_one.sql new file mode 100644 index 000000000..0a1bc8093 --- /dev/null +++ b/e2e/local/dql/generate_patch_basic_one/patch_basic_one.sql @@ -0,0 +1,33 @@ +/* {"URI":"/v1/api/dev/basic/foos","Method":"PATCH","Connector":"dev"} */ + + +import ( + "generate_patch_basic_one.Foos" + ) + + +#set($_ = $Foos(body/).WithTag('anonymous:"true"').Required()) + #set($_ = $CurFoosId(param/Foos) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/` LIMIT 1 +*/ +) + #set($_ = $CurFoos<[]*Foos>(view/CurFoos) /* +? SELECT * FROM FOOS +WHERE $criteria.In("ID", $CurFoosId.Values) +*/ +) +#set($_ = $Foos<>(body/).WithTag('anonymous:"true" typeName:"Foos"').Required().Output()) + + + +$sequencer.Allocate("FOOS", $Foos, "Id") + +#set($CurFoosById = $CurFoos.IndexBy("Id")) + +#if($Foos) + #if($CurFoosById.HasKey($Foos.Id) == true) +$sql.Update($Foos, "FOOS"); + #else +$sql.Insert($Foos, "FOOS"); + #end +#end \ No newline at end of file diff --git a/e2e/local/dql/generate_patch_many_many/patch_basic_many_many.sql b/e2e/local/dql/generate_patch_many_many/patch_basic_many_many.sql new file mode 100644 index 000000000..c7c6b2588 --- /dev/null +++ b/e2e/local/dql/generate_patch_many_many/patch_basic_many_many.sql @@ -0,0 +1,55 @@ +/* {"URI":"/v1/api/dev/basic/foos-many-many","Method":"PATCH","Connector":"dev"} */ + + +import ( + "generate_patch_many_many.Foos" + "generate_patch_many_many.FoosPerformance" + ) + + +#set($_ = $Foos<[]Foos>(body/).WithTag('anonymous:"true"').Required()) + #set($_ = $CurFoosId(param/Foos) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/` LIMIT 1 +*/ +) + #set($_ = $CurFoosFoosPerformanceId(param/Foos) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/FoosPerformance` LIMIT 1 +*/ +) + #set($_ = $CurFoosPerformance<[]*FoosPerformance>(view/CurFoosPerformance) /* +? SELECT * FROM FOOS_PERFORMANCE +WHERE $criteria.In("ID", $CurFoosFoosPerformanceId.Values) +*/ +) + #set($_ = $CurFoos<[]*Foos>(view/CurFoos) /* +? SELECT * FROM FOOS +WHERE $criteria.In("ID", $CurFoosId.Values) +*/ +) +#set($_ = $Foos<[]>(body/).WithTag('anonymous:"true" typeName:"Foos"').Required().Output()) + + + +$sequencer.Allocate("FOOS", $Foos, "Id") + +$sequencer.Allocate("FOOS_PERFORMANCE", $Foos, "FoosPerformance/Id") + +#set($CurFoosById = $CurFoos.IndexBy("Id")) +#set($CurFoosPerformanceById = $CurFoosPerformance.IndexBy("Id")) + +#foreach($RecFoos in $Foos) + #if($CurFoosById.HasKey($RecFoos.Id) == true) +$sql.Update($RecFoos, "FOOS"); + #else +$sql.Insert($RecFoos, "FOOS"); + #end + + #foreach($RecFoosPerformance in $RecFoos.FoosPerformance) + #set($RecFoosPerformance.FooId = $RecFoos.Id) + #if($CurFoosPerformanceById.HasKey($RecFoosPerformance.Id) == true) +$sql.Update($RecFoosPerformance, "FOOS_PERFORMANCE"); + #else +$sql.Insert($RecFoosPerformance, "FOOS_PERFORMANCE"); + #end + #end +#end \ No newline at end of file diff --git a/e2e/local/dql/generate_post_basic_many/post_basic_many.sql b/e2e/local/dql/generate_post_basic_many/post_basic_many.sql new file mode 100644 index 000000000..8e9013974 --- /dev/null +++ b/e2e/local/dql/generate_post_basic_many/post_basic_many.sql @@ -0,0 +1,28 @@ +/* {"URI":"/v1/api/dev/basic/events-many","Method":"POST","Connector":"dev"} */ + + +import ( + "generate_post_basic_many.Events" + ) + + +#set($_ = $Events<[]Events>(body/).WithTag('anonymous:"true"').Required()) + #set($_ = $CurEventsId(param/Events) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/` LIMIT 1 +*/ +) + #set($_ = $CurEvents<[]*Events>(view/CurEvents) /* +? SELECT * FROM EVENTS +WHERE $criteria.In("ID", $CurEventsId.Values) +*/ +) +#set($_ = $Events<[]>(body/).WithTag('anonymous:"true" typeName:"Events"').Required().Output()) + + + +$sequencer.Allocate("EVENTS", $Events, "Id") + + +#foreach($RecEvents in $Events) +$sql.Insert($RecEvents, "EVENTS"); +#end \ No newline at end of file diff --git a/e2e/local/dql/generate_post_basic_one/post_basic_one.sql b/e2e/local/dql/generate_post_basic_one/post_basic_one.sql new file mode 100644 index 000000000..8e7680390 --- /dev/null +++ b/e2e/local/dql/generate_post_basic_one/post_basic_one.sql @@ -0,0 +1,28 @@ +/* {"URI":"/v1/api/dev/basic/events","Method":"POST","Connector":"dev"} */ + + +import ( + "generate_post_basic_one.Events" + ) + + +#set($_ = $Events(body/).WithTag('anonymous:"true"').Required()) + #set($_ = $CurEventsId(param/Events) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/` LIMIT 1 +*/ +) + #set($_ = $CurEvents<*Events>(view/CurEvents) /* +? SELECT * FROM EVENTS +WHERE $criteria.In("ID", $CurEventsId.Values) +*/ +) +#set($_ = $Events<>(body/).WithTag('anonymous:"true" typeName:"Events"').Required().Output()) + + + +$sequencer.Allocate("EVENTS", $Events, "Id") + + +#if($Events) +$sql.Insert($Events, "EVENTS"); +#end \ No newline at end of file diff --git a/e2e/local/dql/generate_post_comprehensive_many/post_comprehensive_many.sql b/e2e/local/dql/generate_post_comprehensive_many/post_comprehensive_many.sql new file mode 100644 index 000000000..15dc98aac --- /dev/null +++ b/e2e/local/dql/generate_post_comprehensive_many/post_comprehensive_many.sql @@ -0,0 +1,29 @@ +/* {"URI":"/v1/api/dev/comprehensive/events-many","Method":"POST","Connector":"dev"} */ + + +import ( + "generate_post_comprehensive_many.Events" + ) + + +#set($_ = $Events<[]Events>(body/Data).Required()) + #set($_ = $CurEventsId(param/Events) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/` LIMIT 1 +*/ +) + #set($_ = $CurEvents<[]*Events>(view/CurEvents) /* +? SELECT * FROM EVENTS +WHERE $criteria.In("ID", $CurEventsId.Values) +*/ +) +#set($_ = $Status<*>(output/status).WithTag('anonymous:"true"').Output()) + #set($_ = $Data<[]>(body/Data).WithTag(' typeName:"Events"').Required().Output()) + + + +$sequencer.Allocate("EVENTS", $Events, "Id") + + +#foreach($RecEvents in $Events) +$sql.Insert($RecEvents, "EVENTS"); +#end \ No newline at end of file diff --git a/e2e/local/dql/generate_post_except/post_except.sql b/e2e/local/dql/generate_post_except/post_except.sql new file mode 100644 index 000000000..d143860ef --- /dev/null +++ b/e2e/local/dql/generate_post_except/post_except.sql @@ -0,0 +1,28 @@ +/* {"URI":"/v1/api/dev/basic/events-except","Method":"POST","Connector":"dev"} */ + + +import ( + "generate_post_except.Events" + ) + + +#set($_ = $Events(body/).WithTag('anonymous:"true"').Required()) + #set($_ = $CurEventsId(param/Events) /* +? SELECT ARRAY_AGG(Id) AS Values FROM `/` LIMIT 1 +*/ +) + #set($_ = $CurEvents<*Events>(view/CurEvents) /* +? SELECT * FROM EVENTS +WHERE $criteria.In("ID", $CurEventsId.Values) +*/ +) +#set($_ = $Events<>(body/).WithTag('anonymous:"true" typeName:"Events"').Required().Output()) + + + +$sequencer.Allocate("EVENTS", $Events, "Id") + + +#if($Events) +$sql.Insert($Events, "EVENTS"); +#end \ No newline at end of file diff --git a/e2e/local/regression/app.yaml b/e2e/local/regression/app.yaml index 33036c290..0d29167db 100644 --- a/e2e/local/regression/app.yaml +++ b/e2e/local/regression/app.yaml @@ -14,7 +14,7 @@ pipeline: immuneToHangups: true env: TEST: 1 - command: ulimit -Sn 10000 && ./datly -c=$appPath/e2e/local/autogen/Datly/config.json -z=/tmp/jobs/datly > /tmp/datly.out + command: ulimit -Sn 10000 && ./datly -c=$appPath/e2e/local/autogen/Datly/config.json -z=/tmp/jobs/datly --mcpPort=8281 > /tmp/datly.out validator: stop: @@ -33,4 +33,4 @@ pipeline: TEST: 1 VALIDATOR_PORT: 8871 - command: ./validator_24 \ No newline at end of file + command: ./validator_24 diff --git a/e2e/local/regression/cases/010_codecs/expect.json b/e2e/local/regression/cases/010_codecs/expect.json deleted file mode 100644 index 52dba1e84..000000000 --- a/e2e/local/regression/cases/010_codecs/expect.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - { - "id": 1, - "name": "Vendor 1", - "accountId": 100, - "userCreated": 1 - }, - { - "id": 2, - "name": "Vendor 2", - "accountId": 101, - "userCreated": 2 - } -] \ No newline at end of file diff --git a/e2e/local/regression/cases/010_codecs/gen.json b/e2e/local/regression/cases/010_codecs/gen.json deleted file mode 100644 index bb1701353..000000000 --- a/e2e/local/regression/cases/010_codecs/gen.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "Name": "$tagId", - "URL": "$path/vendors_codec.sql", - "Args": "" -} \ No newline at end of file diff --git a/e2e/local/regression/cases/010_codecs/vendors_codec.sql b/e2e/local/regression/cases/010_codecs/vendors_codec.sql deleted file mode 100644 index fcef47b03..000000000 --- a/e2e/local/regression/cases/010_codecs/vendors_codec.sql +++ /dev/null @@ -1,6 +0,0 @@ -/* {"URI":"vendors-codec/"} */ - -#set( $_ = $Data(output/view).Embed()) - -SELECT vendor.* -FROM (SELECT * FROM VENDOR t WHERE t.ID IN ($vendorIDs) ) vendor \ No newline at end of file diff --git a/e2e/local/regression/cases/010_grouping/expect_account_totals.json b/e2e/local/regression/cases/010_grouping/expect_account_totals.json new file mode 100644 index 000000000..20d77e4ad --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/expect_account_totals.json @@ -0,0 +1,12 @@ +[ + { + "accountId": 100, + "totalId": 4, + "maxId": 3 + }, + { + "accountId": 101, + "totalId": 2, + "maxId": 2 + } +] diff --git a/e2e/local/regression/cases/010_grouping/expect_account_user_totals.json b/e2e/local/regression/cases/010_grouping/expect_account_user_totals.json new file mode 100644 index 000000000..b0a135943 --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/expect_account_user_totals.json @@ -0,0 +1,14 @@ +[ + { + "accountId": 100, + "userCreated": 1, + "totalId": 4, + "maxId": 3 + }, + { + "accountId": 101, + "userCreated": 2, + "totalId": 2, + "maxId": 2 + } +] diff --git a/e2e/local/regression/cases/010_grouping/expect_empty.json b/e2e/local/regression/cases/010_grouping/expect_empty.json new file mode 100644 index 000000000..fe51488c7 --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/expect_empty.json @@ -0,0 +1 @@ +[] diff --git a/e2e/local/regression/cases/010_grouping/expect_totals.json b/e2e/local/regression/cases/010_grouping/expect_totals.json new file mode 100644 index 000000000..24383af3f --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/expect_totals.json @@ -0,0 +1,6 @@ +[ + { + "totalId": 6, + "maxId": 3 + } +] diff --git a/e2e/local/regression/cases/010_grouping/expect_user_totals.json b/e2e/local/regression/cases/010_grouping/expect_user_totals.json new file mode 100644 index 000000000..5473ea0f4 --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/expect_user_totals.json @@ -0,0 +1,12 @@ +[ + { + "userCreated": 1, + "totalId": 4, + "maxId": 3 + }, + { + "userCreated": 2, + "totalId": 2, + "maxId": 2 + } +] diff --git a/e2e/local/regression/cases/010_grouping/gen.json b/e2e/local/regression/cases/010_grouping/gen.json new file mode 100644 index 000000000..007920478 --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/gen.json @@ -0,0 +1,5 @@ +{ + "Name": "$tagId", + "URL": "$path/vendors_grouping.sql", + "Args": "" +} diff --git a/e2e/local/regression/cases/010_grouping/test.yaml b/e2e/local/regression/cases/010_grouping/test.yaml new file mode 100644 index 000000000..222546276 --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/test.yaml @@ -0,0 +1,186 @@ +init: + parentPath: $parent.path + +pipeline: + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping?vendorIDs=1,2,3&_fields=accountId,totalId,maxId&_orderby=accountId + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_account_totals.json') + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping?vendorIDs=1,2,3&_fields=accountId,userCreated,totalId,maxId&_orderby=accountId + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_account_user_totals.json') + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping?vendorIDs=1,2,3&_fields=userCreated,totalId,maxId&_orderby=userCreated + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_user_totals.json') + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping?vendorIDs=1,2,3&_fields=totalId,maxId + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_totals.json') + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping?vendorIDs=1,2,3&_fields=accountId,totalId,maxId&_orderby=accountId&_limit=1&_offset=2 + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_empty.json') + + - Method: POST + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping/report + JSONBody: + dimensions: + accountId: true + measures: + totalId: true + maxId: true + filters: + vendorIDs: + - 1 + - 2 + - 3 + orderBy: + - accountId + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_account_totals.json') + + - Method: POST + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping/report + JSONBody: + dimensions: + accountId: true + userCreated: true + measures: + totalId: true + maxId: true + filters: + vendorIDs: + - 1 + - 2 + - 3 + orderBy: + - accountId + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_account_user_totals.json') + + - Method: POST + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping/report + JSONBody: + dimensions: {} + measures: + totalId: true + maxId: true + filters: + vendorIDs: + - 1 + - 2 + - 3 + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_totals.json') + + - Method: POST + URL: http://127.0.0.1:8080/v1/api/dev/vendors-grouping/report + JSONBody: + dimensions: + accountId: true + measures: + totalId: true + maxId: true + filters: + vendorIDs: + - 1 + - 2 + - 3 + orderBy: + - accountId + limit: 1 + offset: 2 + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_empty.json') + + mcpInitialize: + action: http/runner:send + requests: + - Method: POST + URL: http://127.0.0.1:8281/mcp + Header: + Content-Type: + - application/json + JSONBody: + jsonrpc: "2.0" + id: 1 + method: initialize + params: + protocolVersion: "2025-06-18" + capabilities: {} + clientInfo: + name: endly + version: "1.0" + Expect: + Code: 200 + post: + protocolVersion: ${Responses[0].JSONBody.result.protocolVersion} + + mcpListTools: + action: http/runner:send + requests: + - Method: POST + URL: http://127.0.0.1:8281/mcp + Header: + Content-Type: + - application/json + MCP-Protocol-Version: + - ${protocolVersion} + JSONBody: + jsonrpc: "2.0" + id: 2 + method: tools/list + params: {} + Expect: + Code: 200 + + mcpCallReport: + action: http/runner:send + requests: + - Method: POST + URL: http://127.0.0.1:8281/mcp + Header: + Content-Type: + - application/json + MCP-Protocol-Version: + - ${protocolVersion} + JSONBody: + jsonrpc: "2.0" + id: 3 + method: tools/call + params: + name: vendorsgroupingReport + arguments: + dimensions: + accountId: true + measures: + totalId: true + maxId: true + filters: + vendorIDs: + - 1 + - 2 + - 3 + orderBy: + - accountId + Expect: + Code: 200 diff --git a/e2e/local/regression/cases/010_grouping/vendors_grouping.sql b/e2e/local/regression/cases/010_grouping/vendors_grouping.sql new file mode 100644 index 000000000..9734a860a --- /dev/null +++ b/e2e/local/regression/cases/010_grouping/vendors_grouping.sql @@ -0,0 +1,18 @@ +/* {"URI":"vendors-grouping/","Name":"vendors grouping","MCPTool":true} */ + +#set( $_ = $report()) +#set( $_ = $Data(output/view).Embed()) +#set( $_ = $VendorIDs<[]int>(query/vendorIDs).WithPredicate(0, 'in', 't', 'ID')) + +SELECT vendor.*, + grouping_enabled(vendor), + allowed_order_by_columns(vendor, 'accountId:ACCOUNT_ID,userCreated:USER_CREATED,totalId:TOTAL_ID,maxId:MAX_ID') +FROM ( + SELECT ACCOUNT_ID, + USER_CREATED, + SUM(ID) AS TOTAL_ID, + MAX(ID) AS MAX_ID + FROM VENDOR t + WHERE t.ID IN ($VendorIDs) + GROUP BY 1, 2 +) vendor diff --git a/e2e/local/regression/regression.yaml b/e2e/local/regression/regression.yaml index a8b575fae..e4e47e667 100644 --- a/e2e/local/regression/regression.yaml +++ b/e2e/local/regression/regression.yaml @@ -30,7 +30,7 @@ pipeline: '[]gen': '@gen' subPath: 'cases/${index}_*' - range: 11..020 + range: 1..015 template: checkSkip: action: nop diff --git a/e2e/v1/build.yaml b/e2e/v1/build.yaml new file mode 100644 index 000000000..7b3fb2f87 --- /dev/null +++ b/e2e/v1/build.yaml @@ -0,0 +1,25 @@ +pipeline: + deploy: + setPath: + action: exec:run + target: $target + checkError: true + commands: + - export GOPATH=${env.GOPATH} + - export PATH=/usr/local/go/bin:$PATH + + set_sdk: + action: sdk.set + target: $target + sdk: go:1.25.5 + + package: + action: exec:run + comments: build plain datly binary for pure DQL bootstrap tests + target: $target + checkError: true + commands: + - export GO111MODULE=on + - export GOFLAGS=-mod=mod + - cd ${appPath} + - go build -ldflags "-X main.BuildTimeInS=`date +%s`" -o /tmp/datly ./cmd/datly diff --git a/e2e/v1/cases/001_relation_one_to_many/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/001_relation_one_to_many/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/001_relation_one_to_many/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/001_relation_one_to_many/dbsetup/dev/VENDOR.json b/e2e/v1/cases/001_relation_one_to_many/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/001_relation_one_to_many/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/001_relation_one_to_many/expect.json b/e2e/v1/cases/001_relation_one_to_many/expect.json new file mode 100644 index 000000000..96437d11c --- /dev/null +++ b/e2e/v1/cases/001_relation_one_to_many/expect.json @@ -0,0 +1,78 @@ +[ + { + "id": 1, + "name": "Vendor 1", + "accountId": 100, + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null, + "products": [ + { + "id": 1, + "name": "V1 Product 1", + "status": 1, + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null + }, + { + "id": 2, + "name": "V1 Product 2", + "status": 1, + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null + } + ] + }, + { + "id": 2, + "name": "Vendor 2", + "accountId": 101, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null, + "products": [ + { + "id": 3, + "name": "V2 Product 1", + "status": 1, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null + }, + { + "id": 4, + "name": "V2 Product 2", + "status": 1, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null + }, + { + "id": 5, + "name": "V2 Product 3", + "status": 1, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null + } + ] + }, + { + "id": 3, + "name": "Vendor 3", + "accountId": 100, + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null + } +] diff --git a/e2e/v1/cases/001_relation_one_to_many/expect_2.txt b/e2e/v1/cases/001_relation_one_to_many/expect_2.txt new file mode 100644 index 000000000..49c4027be --- /dev/null +++ b/e2e/v1/cases/001_relation_one_to_many/expect_2.txt @@ -0,0 +1,27 @@ +package generated + +import ( + "time" +) + +type GeneratedStruct struct { + Id int `sqlx:"ID"` + Name *string `sqlx:"NAME"` + AccountId *int `sqlx:"ACCOUNT_ID"` + Created *time.Time `sqlx:"CREATED"` + UserCreated *int `sqlx:"USER_CREATED"` + Updated *time.Time `sqlx:"UPDATED"` + UserUpdated *int `sqlx:"USER_UPDATED"` + Products []*Products `view:",table=PRODUCT,connector=dev,selectorNamespace=pr"` +} + +type Products struct { + Id int `sqlx:"ID"` + Name *string `sqlx:"NAME"` + VendorId *int `sqlx:"VENDOR_ID" internal:"true"` + Status *int `sqlx:"STATUS"` + Created *time.Time `sqlx:"CREATED"` + UserCreated *int `sqlx:"USER_CREATED"` + Updated *time.Time `sqlx:"UPDATED"` + UserUpdated *int `sqlx:"USER_UPDATED"` +} diff --git a/e2e/v1/cases/001_relation_one_to_many/test.yaml b/e2e/v1/cases/001_relation_one_to_many/test.yaml new file mode 100644 index 000000000..2f186d10c --- /dev/null +++ b/e2e/v1/cases/001_relation_one_to_many/test.yaml @@ -0,0 +1,21 @@ +init: + parentPath: $parent.path + created: $FormatTime('nowInUTC', 'yyyy-MM-ddT00:00:00Z') +pipeline: + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/vendors/ + Expect: + Code: 200 + JSONBody: $LoadData('${parentPath}/expect.json') + + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/meta/struct/dev/vendors/ + Expect: + Code: 200 + Body: $Cat('${parentPath}/expect_2.txt') +#/v1/api/shape/dev/vendors?yy=id,vendorId&xx=id,name,products diff --git a/e2e/v1/cases/002_relation_self_ref_tree/dbsetup/dev/USER.json b/e2e/v1/cases/002_relation_self_ref_tree/dbsetup/dev/USER.json new file mode 100644 index 000000000..ee3de9dab --- /dev/null +++ b/e2e/v1/cases/002_relation_self_ref_tree/dbsetup/dev/USER.json @@ -0,0 +1,37 @@ +[ + {}, + { + "ID": 1, + "NAME": "User 1", + "ACCOUNT_ID": 100 + }, + { + "ID": 2, + "NAME": "User 2", + "ACCOUNT_ID": 101 + }, + { + "ID": 3, + "NAME": "User 3", + "ACCOUNT_ID": 100, + "MGR_ID": 1 + }, + { + "ID": 4, + "NAME": "User 2", + "ACCOUNT_ID": 101, + "MGR_ID": 1 + }, + { + "ID": 5, + "NAME": "User 1", + "ACCOUNT_ID": 100, + "MGR_ID": 3 + }, + { + "ID": 6, + "NAME": "User 2", + "ACCOUNT_ID": 101, + "MGR_ID": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/002_relation_self_ref_tree/expect.json b/e2e/v1/cases/002_relation_self_ref_tree/expect.json new file mode 100644 index 000000000..d75a802c0 --- /dev/null +++ b/e2e/v1/cases/002_relation_self_ref_tree/expect.json @@ -0,0 +1,44 @@ +{ + "data": [ + { + "accountId": 100, + "id": 1, + "name": "User 1", + "team": [ + { + "accountId": 100, + "id": 3, + "name": "User 3", + "team": [ + { + "accountId": 100, + "id": 5, + "name": "User 1", + "team": [] + } + ] + }, + { + "accountId": 101, + "id": 4, + "name": "User 2", + "team": [] + } + ] + }, + { + "accountId": 101, + "id": 2, + "name": "User 2", + "team": [ + { + "accountId": 101, + "id": 6, + "name": "User 2", + "team": [] + } + ] + } + ], + "status": "ok" +} diff --git a/e2e/v1/cases/002_relation_self_ref_tree/test.yaml b/e2e/v1/cases/002_relation_self_ref_tree/test.yaml new file mode 100644 index 000000000..5579809c6 --- /dev/null +++ b/e2e/v1/cases/002_relation_self_ref_tree/test.yaml @@ -0,0 +1,14 @@ +init: + parentPath: $parent.path + expect: $LoadJSON('${parentPath}/expect.json') + +pipeline: + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/users/ + Expect: + Code: 200 + JSONBody: $expect diff --git a/e2e/v1/cases/003_relation_parent_join_optimization/test.yaml b/e2e/v1/cases/003_relation_parent_join_optimization/test.yaml new file mode 100644 index 000000000..0c45ed97b --- /dev/null +++ b/e2e/v1/cases/003_relation_parent_join_optimization/test.yaml @@ -0,0 +1,9 @@ +pipeline: + test: + description: parent join-on optimization is expanded inside child UNION branches + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/col/vendors/ + Expect: + Code: 200 diff --git a/e2e/v1/cases/004_relation_one_to_one/test.yaml b/e2e/v1/cases/004_relation_one_to_one/test.yaml new file mode 100644 index 000000000..cecbb819e --- /dev/null +++ b/e2e/v1/cases/004_relation_one_to_one/test.yaml @@ -0,0 +1,9 @@ +pipeline: + test: + action: http/runner:send + requests: + - Method: GET + description: one-to-one relation via JOIN ... AND 1=1 hint + URL: http://127.0.0.1:8080/v1/api/shape/dev/basic/events-one-one + Expect: + Code: 200 diff --git a/e2e/v1/cases/005_kind_uri_param/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/005_kind_uri_param/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/005_kind_uri_param/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/005_kind_uri_param/dbsetup/dev/VENDOR.json b/e2e/v1/cases/005_kind_uri_param/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/005_kind_uri_param/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/005_kind_uri_param/expect.json b/e2e/v1/cases/005_kind_uri_param/expect.json new file mode 100644 index 000000000..c286edbca --- /dev/null +++ b/e2e/v1/cases/005_kind_uri_param/expect.json @@ -0,0 +1,48 @@ +[ + { + "setting": [ + { + "channel": 3, + "isActive": 1 + } + ], + "vendor": { + "accountId": 101, + "created": "${created}", + "id": 2, + "name": "Vendor 2", + "products": [ + { + "created": "${created}", + "id": 3, + "name": "V2 Product 1", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + }, + { + "created": "${created}", + "id": 4, + "name": "V2 Product 2", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + }, + { + "created": "${created}", + "id": 5, + "name": "V2 Product 3", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + } + ], + "updated": null, + "userCreated": 2, + "userUpdated": null + } + } +] diff --git a/e2e/v1/cases/005_kind_uri_param/test.yaml b/e2e/v1/cases/005_kind_uri_param/test.yaml new file mode 100644 index 000000000..dda6ddff7 --- /dev/null +++ b/e2e/v1/cases/005_kind_uri_param/test.yaml @@ -0,0 +1,14 @@ +init: + parentPath: $parent.path + created: $FormatTime('nowInUTC', 'yyyy-MM-ddT00:00:00Z') + expect: $LoadData('${parentPath}/expect.json') +pipeline: + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/vendors/2 + Expect: + Code: 200 + JSONBody: $expect diff --git a/e2e/v1/cases/006_kind_header_params/expect.json b/e2e/v1/cases/006_kind_header_params/expect.json new file mode 100644 index 000000000..34725db23 --- /dev/null +++ b/e2e/v1/cases/006_kind_header_params/expect.json @@ -0,0 +1,40 @@ +[ + { + "accountId": 101, + "created": "${created}", + "id": 2, + "name": "Vendor 2", + "products": [ + { + "created": "${created}", + "id": 3, + "name": "V2 Product 1", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + }, + { + "created": "${created}", + "id": 4, + "name": "V2 Product 2", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + }, + { + "created": "${created}", + "id": 5, + "name": "V2 Product 3", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + } + ], + "updated": null, + "userCreated": 2, + "userUpdated": null + } +] diff --git a/e2e/v1/cases/006_kind_header_params/test.yaml b/e2e/v1/cases/006_kind_header_params/test.yaml new file mode 100644 index 000000000..5cee8ec46 --- /dev/null +++ b/e2e/v1/cases/006_kind_header_params/test.yaml @@ -0,0 +1,16 @@ +init: + parentPath: $parent.path + created: $FormatTime('nowInUTC', 'yyyy-MM-ddT00:00:00Z') + expect: $LoadData('${parentPath}/expect.json') +pipeline: + test: + description: inline SQL header hint binds Vendor-Id into the request parameter + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/headers/vendors + Header: + Vendor-Id: ["2"] + Expect: + Code: 200 + JSONBody: $expect diff --git a/e2e/v1/cases/007_kind_const/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/007_kind_const/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/007_kind_const/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/007_kind_const/dbsetup/dev/VENDOR.json b/e2e/v1/cases/007_kind_const/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/007_kind_const/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/007_kind_const/expect.json b/e2e/v1/cases/007_kind_const/expect.json new file mode 100644 index 000000000..b7f47634e --- /dev/null +++ b/e2e/v1/cases/007_kind_const/expect.json @@ -0,0 +1,74 @@ +[ + { + "accountId": 100, + "created": "@exists@", + "id": 1, + "name": "Vendor 1", + "products": [ + { + "created": "@exists@", + "status": 1, + "id": 1, + "name": "V1 Product 1", + "updated": null, + "userCreated": 1, + "userUpdated": null, + "vendorId": 1 + }, + { + "created": "@exists@", + "id": 2, + "name": "V1 Product 2", + "status": 1, + "updated": null, + "userCreated": 1, + "userUpdated": null, + "vendorId": 1 + } + ], + "updated": null, + "userCreated": 1, + "userUpdated": null + }, + { + "accountId": 101, + "created": "@exists@", + "id": 2, + "name": "Vendor 2", + "products": [ + { + "created": "@exists@", + "id": 3, + "name": "V2 Product 1", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null, + "vendorId": 2 + }, + { + "created": "@exists@", + "id": 4, + "name": "V2 Product 2", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null, + "vendorId": 2 + }, + { + "created": "@exists@", + "id": 5, + "name": "V2 Product 3", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null, + "vendorId": 2 + } + ], + "updated": null, + "userCreated": 2, + "userUpdated": null + } +] diff --git a/e2e/v1/cases/007_kind_const/test.yaml b/e2e/v1/cases/007_kind_const/test.yaml new file mode 100644 index 000000000..f476e3ed6 --- /dev/null +++ b/e2e/v1/cases/007_kind_const/test.yaml @@ -0,0 +1,22 @@ +init: + parentPath: $parent.path + created: $FormatTime('nowInUTC', 'yyyy-MM-ddT00:00:00Z') + expect: $LoadData('${parentPath}/expect.json') + +pipeline: + printHello: + action: print + message: hello action 1 + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/vendors-env?vendorIDs=1,2 + Expect: + Code: 200 + JSONBody: $expect + + info: + action: print + message: $AsJSON($test) diff --git a/e2e/v1/cases/008_summary_root/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/008_summary_root/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/008_summary_root/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/008_summary_root/dbsetup/dev/VENDOR.json b/e2e/v1/cases/008_summary_root/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/008_summary_root/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/008_summary_root/test.yaml b/e2e/v1/cases/008_summary_root/test.yaml new file mode 100644 index 000000000..75c61afe2 --- /dev/null +++ b/e2e/v1/cases/008_summary_root/test.yaml @@ -0,0 +1,15 @@ + +pipeline: + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/meta/vendors + Expect: + Code: 200 + JSONBody: + status: ok + meta: + pageCnt: 1 + cnt: 3 \ No newline at end of file diff --git a/e2e/v1/cases/009_summary_child/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/009_summary_child/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/009_summary_child/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/009_summary_child/dbsetup/dev/VENDOR.json b/e2e/v1/cases/009_summary_child/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/009_summary_child/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/009_summary_child/expect.json b/e2e/v1/cases/009_summary_child/expect.json new file mode 100644 index 000000000..58f9a04d8 --- /dev/null +++ b/e2e/v1/cases/009_summary_child/expect.json @@ -0,0 +1,94 @@ +{ + "data": [ + { + "accountId": 100, + "created": "@exists@", + "id": 1, + "name": "Vendor 1", + "products": [ + { + "created": "@exists@", + "id": 1, + "name": "V1 Product 1", + "updated": null, + "userCreated": 1, + "userUpdated": null + }, + { + "created": "@exists@", + "id": 2, + "name": "V1 Product 2", + "status": 1, + "updated": null, + "userCreated": 1, + "userUpdated": null + } + ], + "productsMeta": { + "pageCnt": 1, + "totalProducts": 2 + }, + "updated": null, + "userCreated": 1, + "userUpdated": null + }, + { + "accountId": 101, + "created": "@exists@", + "id": 2, + "name": "Vendor 2", + "products": [ + { + "created": "@exists@", + "id": 3, + "name": "V2 Product 1", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + }, + { + "created": "@exists@", + "id": 4, + "name": "V2 Product 2", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + }, + { + "created": "@exists@", + "id": 5, + "name": "V2 Product 3", + "status": 1, + "updated": null, + "userCreated": 2, + "userUpdated": null + } + ], + "productsMeta": { + "pageCnt": 1, + "totalProducts": 3 + }, + "updated": null, + "userCreated": 2, + "userUpdated": null + }, + { + "accountId": 100, + "created": "@exists@", + "id": 3, + "name": "Vendor 3", + "products": [], + "productsMeta": null, + "updated": null, + "userCreated": 1, + "userUpdated": null + } + ], + "meta": { + "cnt": 3, + "pageCnt": 1 + }, + "status": "ok" +} diff --git a/e2e/local/regression/cases/010_codecs/test.yaml b/e2e/v1/cases/009_summary_child/test.yaml similarity index 75% rename from e2e/local/regression/cases/010_codecs/test.yaml rename to e2e/v1/cases/009_summary_child/test.yaml index 75d28dd2e..8d5a58c70 100644 --- a/e2e/local/regression/cases/010_codecs/test.yaml +++ b/e2e/v1/cases/009_summary_child/test.yaml @@ -8,7 +8,7 @@ pipeline: action: http/runner:send requests: - Method: GET - URL: http://127.0.0.1:8080/v1/api/dev/vendors-codec?vendorIDs=1,2 + URL: http://127.0.0.1:8080/v1/api/shape/dev/meta/vendors-nested Expect: Code: 200 JSONBody: $expect \ No newline at end of file diff --git a/e2e/v1/cases/010_summary_multi/expect.json b/e2e/v1/cases/010_summary_multi/expect.json new file mode 100644 index 000000000..4488c62db --- /dev/null +++ b/e2e/v1/cases/010_summary_multi/expect.json @@ -0,0 +1,95 @@ +{ + "meta": { + "pageCnt": 1, + "cnt": 3 + }, + "data": [ + { + "id": 1, + "name": "Vendor 1", + "accountId": 100, + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null, + "products": [ + { + "id": 1, + "name": "V1 Product 1", + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null + }, + { + "id": 2, + "name": "V1 Product 2", + "status": 1, + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null + } + ], + "productsMeta": { + "vendorId": 1, + "pageCnt": 1, + "totalProducts": 2 + } + }, + { + "id": 2, + "name": "Vendor 2", + "accountId": 101, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null, + "products": [ + { + "id": 3, + "name": "V2 Product 1", + "status": 1, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null + }, + { + "id": 4, + "name": "V2 Product 2", + "status": 1, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null + }, + { + "id": 5, + "name": "V2 Product 3", + "status": 1, + "created": "@exists@", + "userCreated": 2, + "updated": null, + "userUpdated": null + } + ], + "productsMeta": { + "vendorId": 2, + "pageCnt": 1, + "totalProducts": 3 + } + }, + { + "id": 3, + "name": "Vendor 3", + "accountId": 100, + "created": "@exists@", + "userCreated": 1, + "updated": null, + "userUpdated": null, + "products": [] + } + ], + "status": "ok" +} diff --git a/e2e/v1/cases/010_summary_multi/test.yaml b/e2e/v1/cases/010_summary_multi/test.yaml new file mode 100644 index 000000000..a4836c0b8 --- /dev/null +++ b/e2e/v1/cases/010_summary_multi/test.yaml @@ -0,0 +1,13 @@ +init: + parentPath: $parent.path + expect: $LoadJSON('${parentPath}/expect.json') +pipeline: + test: + description: multiple summary views can be joined off the same root projection + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/meta/vendors-format + Expect: + Code: 200 + JSONBody: $expect diff --git a/e2e/v1/cases/011_summary_pagination/dbsetup/dev/CITY.json b/e2e/v1/cases/011_summary_pagination/dbsetup/dev/CITY.json new file mode 100644 index 000000000..2f8364176 --- /dev/null +++ b/e2e/v1/cases/011_summary_pagination/dbsetup/dev/CITY.json @@ -0,0 +1,27 @@ +[ + {}, + { + "ID": 1, + "NAME": "district - 1 / city - 1", + "ZIP_CODE": "12-345", + "DISTRICT_ID": 1 + }, + { + "ID": 2, + "NAME": "district - 2 / city - 1", + "DISTRICT_ID": 2, + "ZIP_CODE": "23-456" + }, + { + "ID": 3, + "NAME": "district - 1 / city - 2", + "DISTRICT_ID": 1, + "ZIP_CODE": "34-567" + }, + { + "ID": 4, + "NAME": "district - 1 / city - 3", + "DISTRICT_ID": 1, + "ZIP_CODE": "45_678" + } +] \ No newline at end of file diff --git a/e2e/v1/cases/011_summary_pagination/dbsetup/dev/DISTRICT.json b/e2e/v1/cases/011_summary_pagination/dbsetup/dev/DISTRICT.json new file mode 100644 index 000000000..3bed6ed69 --- /dev/null +++ b/e2e/v1/cases/011_summary_pagination/dbsetup/dev/DISTRICT.json @@ -0,0 +1,15 @@ +[ + {}, + { + "ID": 1, + "NAME": "district - 1" + }, + { + "ID": 2, + "NAME": "district - 2" + }, + { + "ID": 3, + "NAME": "district - 3" + } +] \ No newline at end of file diff --git a/e2e/v1/cases/011_summary_pagination/expect.json b/e2e/v1/cases/011_summary_pagination/expect.json new file mode 100644 index 000000000..b7571133a --- /dev/null +++ b/e2e/v1/cases/011_summary_pagination/expect.json @@ -0,0 +1,32 @@ +[ + { + "id": 1, + "name": "district - 1", + "cities": [ + { + "id": 1, + "name": "district - 1 / city - 1", + "zipCode": "12-345", + "districtId": 1 + }, + { + "id": 3, + "name": "district - 1 / city - 2", + "zipCode": "34-567", + "districtId": 1 + } + ] + }, + { + "id": 2, + "name": "district - 2", + "cities": [ + { + "id": 2, + "name": "district - 2 / city - 1", + "zipCode": "23-456", + "districtId": 2 + } + ] + } +] \ No newline at end of file diff --git a/e2e/v1/cases/011_summary_pagination/test.yaml b/e2e/v1/cases/011_summary_pagination/test.yaml new file mode 100644 index 000000000..8dd244aa9 --- /dev/null +++ b/e2e/v1/cases/011_summary_pagination/test.yaml @@ -0,0 +1,14 @@ +init: + parentPath: $parent.path + expect: $LoadJSON('${parentPath}/expect.json') + +pipeline: + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/meta/districts?IDs=1,2 + Expect: + Code: 200 + JSONBody: $expect \ No newline at end of file diff --git a/e2e/v1/cases/012_auth_oauth/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/012_auth_oauth/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/012_auth_oauth/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/012_auth_oauth/dbsetup/dev/VENDOR.json b/e2e/v1/cases/012_auth_oauth/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/012_auth_oauth/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/012_auth_oauth/expect.json b/e2e/v1/cases/012_auth_oauth/expect.json new file mode 100644 index 000000000..854079f16 --- /dev/null +++ b/e2e/v1/cases/012_auth_oauth/expect.json @@ -0,0 +1,27 @@ +[ + { + "id": 2, + "name": "Vendor 2", + "firstName": "Developer", + "products": [ + { + "@indexBy@": "id" + }, + { + "id": 3, + "name": "V2 Product 1", + "userCreated": 2 + }, + { + "id": 4, + "name": "V2 Product 2", + "userCreated": 2 + }, + { + "id": 5, + "name": "V2 Product 3", + "userCreated": 2 + } + ] + } +] \ No newline at end of file diff --git a/e2e/v1/cases/012_auth_oauth/test.yaml b/e2e/v1/cases/012_auth_oauth/test.yaml new file mode 100644 index 000000000..b78409a9c --- /dev/null +++ b/e2e/v1/cases/012_auth_oauth/test.yaml @@ -0,0 +1,53 @@ +init: + parentPath: $parent.path + expect: $LoadData('${parentPath}/expect.json') +pipeline: + + + signJWT: + action: secret:signJWT + privateKey: + URL: ${appPath}/e2e/cloud/jwt/private.enc + Key: blowfish://default + claims: + userID: 2 + firstName: Developer + email: dev@viantint.com + + printToken: + action: print + message: Bearer ${signJWT.TokenString} + + + test: +# testNoAuthenticated: +# action: http/runner:send +# requests: +# - Method: GET +# description: user is no authenticated +# URL: http://127.0.0.1:8080/v1/api/shape/dev/auth/vendors/2 +# Expect: +# Code: 401 +# +# testAuthenticatedAndAuthorized: +# action: http/runner:send +# requests: +# - Method: GET +# description: user is authenticated and authorized for vendor 2 +# URL: http://127.0.0.1:8080/v1/api/shape/dev/auth/vendors/2 +# Header: +# Authorization: Bearer ${signJWT.TokenString} +# Expect: +# Code: 200 +# JSONBody: $expect + + testAuthenticatedAndNoAuthorized: + action: http/runner:send + requests: + - Method: GET + description: user is authenticated but not authorized for vendor 1 (no data returned) + URL: http://127.0.0.1:8080/v1/api/shape/dev/auth/vendors/1 + Header: + Authorization: Bearer ${signJWT.TokenString} + Expect: + Code: 403 diff --git a/e2e/v1/cases/013_kind_mysql_boolean/expect.json b/e2e/v1/cases/013_kind_mysql_boolean/expect.json new file mode 100644 index 000000000..b988773ac --- /dev/null +++ b/e2e/v1/cases/013_kind_mysql_boolean/expect.json @@ -0,0 +1,17 @@ +[ + { + "@indexBy@": "id" + }, + { + "id": 1, + "userId": 1, + "isEnabled": true, + "isActivated": false + }, + { + "id": 2, + "userId": 2, + "isEnabled": false, + "isActivated": true + } +] diff --git a/e2e/v1/cases/013_kind_mysql_boolean/test.yaml b/e2e/v1/cases/013_kind_mysql_boolean/test.yaml new file mode 100644 index 000000000..b727b3e19 --- /dev/null +++ b/e2e/v1/cases/013_kind_mysql_boolean/test.yaml @@ -0,0 +1,10 @@ +pipeline: + test: + description: MySQL boolean columns are projected with the expected output shape + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/user-metadata + Expect: + Code: 200 + JSONBody: $LoadJSON('${parent.path}/expect.json') diff --git a/e2e/v1/cases/014_cache_sql_apikey/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/014_cache_sql_apikey/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/014_cache_sql_apikey/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/014_cache_sql_apikey/dbsetup/dev/VENDOR.json b/e2e/v1/cases/014_cache_sql_apikey/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/014_cache_sql_apikey/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/014_cache_sql_apikey/expect.json b/e2e/v1/cases/014_cache_sql_apikey/expect.json new file mode 100644 index 000000000..b963eee2e --- /dev/null +++ b/e2e/v1/cases/014_cache_sql_apikey/expect.json @@ -0,0 +1,26 @@ +[ + { + "id": 2, + "name": "Vendor 2", + "products": [ + { + "@indexBy@": "id" + }, + { + "id": 3, + "name": "V2 Product 1", + "userCreated": 2 + }, + { + "id": 4, + "name": "V2 Product 2", + "userCreated": 2 + }, + { + "id": 5, + "name": "V2 Product 3", + "userCreated": 2 + } + ] + } +] \ No newline at end of file diff --git a/e2e/v1/cases/014_cache_sql_apikey/test.yaml b/e2e/v1/cases/014_cache_sql_apikey/test.yaml new file mode 100644 index 000000000..a9ce0de9a --- /dev/null +++ b/e2e/v1/cases/014_cache_sql_apikey/test.yaml @@ -0,0 +1,46 @@ +init: + parentPath: $parent.path + expect: $LoadData('${parentPath}/expect.json') +pipeline: + + test: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/secured/vendors/2 + Expect: + Code: 403 + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/secured/vendors/2 + Header: + App-Secret-Id: 'changeme' + Expect: + Code: 200 + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/meta/view/dev/secured/vendors/2 + Expect: + Code: 403 + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/meta/view/dev/secured/vendors/2 + Header: + App-Secret-Id: 'changeme' + Expect: + Code: 200 + + test2: + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/meta/openapi/dev/secured/vendors/2 + Expect: + Code: 403 + + - Method: GET + URL: http://127.0.0.1:8080/v1/api/meta/openapi/dev/secured/vendors/2 + Header: + App-Secret-Id: 'changeme' + Expect: + Code: 200 diff --git a/e2e/v1/cases/015_dml_update/dbsetup/dev/PRODUCT.json b/e2e/v1/cases/015_dml_update/dbsetup/dev/PRODUCT.json new file mode 100644 index 000000000..9f31630e2 --- /dev/null +++ b/e2e/v1/cases/015_dml_update/dbsetup/dev/PRODUCT.json @@ -0,0 +1,44 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "V1 Product 2", + "VENDOR_ID": 1, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 1 + }, + + { + "ID": 3, + "NAME": "V2 Product 1", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 4, + "NAME": "V2 Product 2", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + }, + { + "ID": 5, + "NAME": "V2 Product 3", + "VENDOR_ID": 2, + "CREATED": "", + "STATUS": 1, + "USER_CREATED": 2 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/015_dml_update/dbsetup/dev/PRODUCT_JN.json b/e2e/v1/cases/015_dml_update/dbsetup/dev/PRODUCT_JN.json new file mode 100644 index 000000000..ec2649bb4 --- /dev/null +++ b/e2e/v1/cases/015_dml_update/dbsetup/dev/PRODUCT_JN.json @@ -0,0 +1,3 @@ +[ + {} +] \ No newline at end of file diff --git a/e2e/v1/cases/015_dml_update/dbsetup/dev/VENDOR.json b/e2e/v1/cases/015_dml_update/dbsetup/dev/VENDOR.json new file mode 100644 index 000000000..c4c724a74 --- /dev/null +++ b/e2e/v1/cases/015_dml_update/dbsetup/dev/VENDOR.json @@ -0,0 +1,24 @@ +[ + {}, + { + "ID": 1, + "NAME": "Vendor 1", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + }, + { + "ID": 2, + "NAME": "Vendor 2", + "ACCOUNT_ID": 101, + "CREATED": "", + "USER_CREATED": 2 + }, + { + "ID": 3, + "NAME": "Vendor 3", + "ACCOUNT_ID": 100, + "CREATED": "", + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/015_dml_update/expect.json b/e2e/v1/cases/015_dml_update/expect.json new file mode 100644 index 000000000..b963eee2e --- /dev/null +++ b/e2e/v1/cases/015_dml_update/expect.json @@ -0,0 +1,26 @@ +[ + { + "id": 2, + "name": "Vendor 2", + "products": [ + { + "@indexBy@": "id" + }, + { + "id": 3, + "name": "V2 Product 1", + "userCreated": 2 + }, + { + "id": 4, + "name": "V2 Product 2", + "userCreated": 2 + }, + { + "id": 5, + "name": "V2 Product 3", + "userCreated": 2 + } + ] + } +] \ No newline at end of file diff --git a/e2e/v1/cases/015_dml_update/expect/PRODUCT.json b/e2e/v1/cases/015_dml_update/expect/PRODUCT.json new file mode 100644 index 000000000..2138e9e62 --- /dev/null +++ b/e2e/v1/cases/015_dml_update/expect/PRODUCT.json @@ -0,0 +1,10 @@ +[ + {}, + { + "ID": 1, + "NAME": "V1 Product 1", + "VENDOR_ID": 1, + "STATUS": 2, + "USER_CREATED": 1 + } +] \ No newline at end of file diff --git a/e2e/v1/cases/015_dml_update/test.yaml b/e2e/v1/cases/015_dml_update/test.yaml new file mode 100644 index 000000000..93cee7db8 --- /dev/null +++ b/e2e/v1/cases/015_dml_update/test.yaml @@ -0,0 +1,53 @@ +init: + parentPath: $parent.path + expect: $LoadData('${parentPath}/expect.json') +pipeline: + + + signJWT: + action: secret:signJWT + privateKey: + URL: ${appPath}/e2e/cloud/jwt/private.enc + Key: blowfish://default + claims: + userID: 1 + email: dev@viantint.com + + printToken: + action: print + message: Bearer ${signJWT.TokenString} + + + test: + testNoAuthenticated: + action: http/runner:send + requests: + - Method: POST + description: user is authenticated + URL: http://127.0.0.1:8080/v1/api/shape/dev/auth/products/ + Header: + Authorization: Bearer ${signJWT.TokenString} + JSONBody: + Ids: + - 1 + Status: 2 + Expect: + Code: 200 + +# +# - Method: POST +# description: user is no authenticated +# URL: http://127.0.0.1:8080/v1/api/shape/dev/auth/products/ +# JSONBody: +# Ids: +# - 1 +# Status: 2 +# Expect: +# Code: 401 + + checkDb: + action: 'dsunit:expect' + datastore: dev + expand: true + checkPolicy: 1 + URL: ${parentPath}/expect diff --git a/e2e/v1/cases/016_dml_delete/test.yaml b/e2e/v1/cases/016_dml_delete/test.yaml new file mode 100644 index 000000000..4b90699b2 --- /dev/null +++ b/e2e/v1/cases/016_dml_delete/test.yaml @@ -0,0 +1,16 @@ +pipeline: + test: + description: delete executor removes the targeted team row by path parameter + action: http/runner:send + requests: + - Method: DELETE + URL: http://127.0.0.1:8080/v1/api/shape/dev/team/1000000 + Expect: + Code: 200 + + checkDb: + action: dsunit:query + datastore: dev + SQL: 'SELECT COUNT(*) AS NUM_RECORDS FROM (SELECT 1 FROM TEAM WHERE ID = 1000000) T' + expect: + - NUM_RECORDS: 0 diff --git a/e2e/v1/cases/017_kind_variables/expect.json b/e2e/v1/cases/017_kind_variables/expect.json new file mode 100644 index 000000000..062d072e7 --- /dev/null +++ b/e2e/v1/cases/017_kind_variables/expect.json @@ -0,0 +1,7 @@ +[ + { + "key1": "setting1 - VENDOR", + "key2": "setting2 - PRODUCT", + "key3": true + } +] diff --git a/e2e/v1/cases/017_kind_variables/test.yaml b/e2e/v1/cases/017_kind_variables/test.yaml new file mode 100644 index 000000000..ee1b80edd --- /dev/null +++ b/e2e/v1/cases/017_kind_variables/test.yaml @@ -0,0 +1,14 @@ +init: + parentPath: $parent.path + expect: $LoadJSON('${parentPath}/expect.json') + +pipeline: + test: + description: constant variables are substituted into SQL expressions before execution + action: http/runner:send + requests: + - Method: GET + URL: http://127.0.0.1:8080/v1/api/shape/dev/ws/vars/ + Expect: + Code: 200 + JSONBody: $expect diff --git a/e2e/v1/cases/018_exec_index_by/dbsetup/dev/TEAM.json b/e2e/v1/cases/018_exec_index_by/dbsetup/dev/TEAM.json new file mode 100644 index 000000000..aa65096e6 --- /dev/null +++ b/e2e/v1/cases/018_exec_index_by/dbsetup/dev/TEAM.json @@ -0,0 +1,18 @@ +[ + {}, + { + "ID": 1, + "NAME": "Team - 1", + "ACTIVE": true + }, + { + "ID": 2, + "NAME": "Team - 2", + "ACTIVE": true + }, + { + "ID": 3, + "NAME": "Team - 3", + "ACTIVE": true + } +] diff --git a/e2e/v1/cases/018_exec_index_by/dbsetup/dev/USER_TEAM.json b/e2e/v1/cases/018_exec_index_by/dbsetup/dev/USER_TEAM.json new file mode 100644 index 000000000..35c2f6810 --- /dev/null +++ b/e2e/v1/cases/018_exec_index_by/dbsetup/dev/USER_TEAM.json @@ -0,0 +1,13 @@ +[ + {}, + { + "ID": 1, + "USER_ID": 1, + "TEAM_ID": 1 + }, + { + "ID": 2, + "USER_ID": 2, + "TEAM_ID": 1 + } +] diff --git a/e2e/v1/cases/018_exec_index_by/expect/TEAM.json b/e2e/v1/cases/018_exec_index_by/expect/TEAM.json new file mode 100644 index 000000000..308755ef4 --- /dev/null +++ b/e2e/v1/cases/018_exec_index_by/expect/TEAM.json @@ -0,0 +1,18 @@ +[ + {}, + { + "ID": 1, + "NAME": "Team - 1", + "ACTIVE": true + }, + { + "ID": 2, + "NAME": "Team - 2", + "ACTIVE": true + }, + { + "ID": 3, + "NAME": "Team - 3", + "ACTIVE": false + } +] diff --git a/e2e/v1/cases/018_exec_index_by/test.yaml b/e2e/v1/cases/018_exec_index_by/test.yaml new file mode 100644 index 000000000..2211c4b64 --- /dev/null +++ b/e2e/v1/cases/018_exec_index_by/test.yaml @@ -0,0 +1,33 @@ +init: + parentPath: $parent.path + +pipeline: + test: + description: indexed view state is used to validate team membership before executor updates + action: http/runner:send + requests: + - Method: PUT + URL: http://127.0.0.1:8080/v1/api/shape/dev/teams?TeamIDs=100 + Expect: + Code: 400 + Body: + message: "not found team with ID 100" + + - Method: PUT + URL: http://127.0.0.1:8080/v1/api/shape/dev/teams?TeamIDs=1 + Expect: + Code: 400 + Body: + message: "can't deactivate team Team - 1 with 2 members" + + - Method: PUT + URL: http://127.0.0.1:8080/v1/api/shape/dev/teams?TeamIDs=3 + Expect: + Code: 200 + + checkDb: + action: dsunit:expect + datastore: dev + expand: true + checkPolicy: 1 + URL: ${parentPath}/expect diff --git a/e2e/v1/cases/019_component_dependency/expect_admin.json b/e2e/v1/cases/019_component_dependency/expect_admin.json new file mode 100644 index 000000000..86326dd5a --- /dev/null +++ b/e2e/v1/cases/019_component_dependency/expect_admin.json @@ -0,0 +1,9 @@ +[ + { + "@indexBy@": "id" + }, + { + "id": 2, + "name": "Vendor 2" + } +] diff --git a/e2e/v1/cases/019_component_dependency/expect_readonly.json b/e2e/v1/cases/019_component_dependency/expect_readonly.json new file mode 100644 index 000000000..02513fe0b --- /dev/null +++ b/e2e/v1/cases/019_component_dependency/expect_readonly.json @@ -0,0 +1,9 @@ +[ + { + "@indexBy@": "id" + }, + { + "id": 1, + "name": "Vendor 1" + } +] diff --git a/e2e/v1/cases/019_component_dependency/test.yaml b/e2e/v1/cases/019_component_dependency/test.yaml new file mode 100644 index 000000000..d887709d6 --- /dev/null +++ b/e2e/v1/cases/019_component_dependency/test.yaml @@ -0,0 +1,49 @@ +init: + parentPath: $parent.path + expectReadOnly: $LoadData('${parentPath}/expect_readonly.json') + expectAdmin: $LoadData('${parentPath}/expect_admin.json') + +pipeline: + signReadOnly: + action: secret:signJWT + privateKey: + URL: ${appPath}/e2e/cloud/jwt/private.enc + Key: blowfish://default + claims: + userID: 1 + firstName: Tester + email: tester@viantint.com + + testReadOnly: + action: http/runner:send + requests: + - Method: GET + description: component dependency resolves a sibling DQL component and gates vendor access for read-only users + URL: http://127.0.0.1:8080/v1/api/shape/dev/vendors/component-acl + Header: + Authorization: Bearer ${signReadOnly.TokenString} + Expect: + Code: 200 + JSONBody: $expectReadOnly + + signAdmin: + action: secret:signJWT + privateKey: + URL: ${appPath}/e2e/cloud/jwt/private.enc + Key: blowfish://default + claims: + userID: 2 + firstName: Developer + email: dev@viantint.com + + testAdmin: + action: http/runner:send + requests: + - Method: GET + description: component dependency keeps query predicate filtering for non-read-only users + URL: http://127.0.0.1:8080/v1/api/shape/dev/vendors/component-acl?name=2 + Header: + Authorization: Bearer ${signAdmin.TokenString} + Expect: + Code: 200 + JSONBody: $expectAdmin diff --git a/e2e/v1/cases/020_generate_patch_basic_one/dbsetup/dev/FOOS.json b/e2e/v1/cases/020_generate_patch_basic_one/dbsetup/dev/FOOS.json new file mode 100644 index 000000000..56cc97b53 --- /dev/null +++ b/e2e/v1/cases/020_generate_patch_basic_one/dbsetup/dev/FOOS.json @@ -0,0 +1,23 @@ +[ + {}, + { + "ID": 1, + "NAME": "foo 1", + "QUANTITY": 100 + }, + { + "ID": 2, + "NAME": "foo 2", + "QUANTITY": 200 + }, + { + "ID": 3, + "NAME": "foo 3", + "QUANTITY": 300 + }, + { + "ID": 4, + "NAME": "foo 4", + "QUANTITY": 400 + } +] diff --git a/e2e/v1/cases/020_generate_patch_basic_one/expect_t0.json b/e2e/v1/cases/020_generate_patch_basic_one/expect_t0.json new file mode 100644 index 000000000..a5feed5e4 --- /dev/null +++ b/e2e/v1/cases/020_generate_patch_basic_one/expect_t0.json @@ -0,0 +1,5 @@ +{ + "id": 4, + "quantity": 2500, + "name": "changed - foo 4" +} diff --git a/e2e/v1/cases/020_generate_patch_basic_one/expect_t1.json b/e2e/v1/cases/020_generate_patch_basic_one/expect_t1.json new file mode 100644 index 000000000..caed750cf --- /dev/null +++ b/e2e/v1/cases/020_generate_patch_basic_one/expect_t1.json @@ -0,0 +1,5 @@ +{ + "id": "@exists@", + "quantity": 1234, + "name": "created" +} diff --git a/e2e/v1/cases/020_generate_patch_basic_one/test.yaml b/e2e/v1/cases/020_generate_patch_basic_one/test.yaml new file mode 100644 index 000000000..ecfa2612c --- /dev/null +++ b/e2e/v1/cases/020_generate_patch_basic_one/test.yaml @@ -0,0 +1,41 @@ +init: + parentPath: $parent.path + +pipeline: + test: + description: generated PATCH component updates an existing FOOS row and inserts a new one when the primary key is absent + action: http/runner:send + requests: + - Method: PATCH + URL: http://127.0.0.1:8080/v1/api/shape/dev/basic/foos + JsonBody: + ID: 4 + Quantity: 2500 + Name: 'changed - foo 4' + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_t0.json') + + - Method: PATCH + URL: http://127.0.0.1:8080/v1/api/shape/dev/basic/foos + JsonBody: + Quantity: 1234 + Name: 'created' + Expect: + Code: 200 + JSONBody: $LoadJSON('${parentPath}/expect_t1.json') + + checkDb: + action: dsunit:query + datastore: dev + sql: | + SELECT + (CASE WHEN EXISTS(SELECT 1 FROM FOOS WHERE ID = 4 AND QUANTITY = 2500 AND NAME = 'changed - foo 4') THEN TRUE ELSE FALSE END) AS UPDATED_ROW, + (CASE WHEN EXISTS(SELECT 1 FROM FOOS WHERE ID <> 4 AND QUANTITY = 1234 AND NAME = 'created') THEN TRUE ELSE FALSE END) AS INSERTED_ROW, + (SELECT COUNT(*) FROM FOOS) AS TOTAL_ROWS, + (SELECT COUNT(*) FROM FOOS WHERE NAME = 'created' AND QUANTITY = 1234) AS INSERTED_COUNT + expect: + - UPDATED_ROW: true + INSERTED_ROW: true + TOTAL_ROWS: 11 + INSERTED_COUNT: 1 diff --git a/e2e/v1/config.json b/e2e/v1/config.json new file mode 100644 index 000000000..fc6d24787 --- /dev/null +++ b/e2e/v1/config.json @@ -0,0 +1,18 @@ +{ + "APIPrefix": "/v1/api", + "APIKeys": [ + { + "Header": "App-Secret-Id", + "URI": "/v1/api/shape/dev/secured", + "Value": "changeme" + } + ], + "Endpoint": { + "Port": 8080 + }, + "Meta": { + "StatusURI": "/v1/api/status", + "StructURI": "/v1/api/meta/struct" + }, + "SyncFrequencyMs": 3600000 +} diff --git a/e2e/v1/datastore.yaml b/e2e/v1/datastore.yaml new file mode 100644 index 000000000..08160b622 --- /dev/null +++ b/e2e/v1/datastore.yaml @@ -0,0 +1,27 @@ +init: + +pipeline: + mysql: + create: + action: dsunit:init + datastore: dev + recreate: false + config: + driverName: mysql + descriptor: '[username]:[password]@tcp(${dbIP.mysql}:3306)/[dbname]?parseTime=true' + credentials: $mysqlCredentials + admin: + datastore: mysql + ping: true + config: + driverName: mysql + descriptor: '[username]:[password]@tcp(${dbIP.mysql}:3306)/[dbname]?parseTime=true' + credentials: $mysqlCredentials + scripts: + - URL: ${v1Path}/db/schema.sql + + prepare: + action: 'dsunit:prepare' + datastore: dev + expand: true + URL: ${appPath}/e2e/local/datastore/mysql/populate diff --git a/e2e/v1/db/schema.sql b/e2e/v1/db/schema.sql new file mode 100644 index 000000000..f6192f9aa --- /dev/null +++ b/e2e/v1/db/schema.sql @@ -0,0 +1,227 @@ +SET GLOBAL log_bin_trust_function_creators = 1; +SET GLOBAL sql_mode = ''; + +DROP TABLE IF EXISTS USER; +CREATE TABLE USER ( + ID INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + NAME VARCHAR(255), + MGR_ID INT, + ACCOUNT_ID INT +); + +DROP TABLE IF EXISTS VENDOR; +CREATE TABLE VENDOR ( + ID INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + NAME VARCHAR(255), + ACCOUNT_ID INT, + CREATED DATETIME, + USER_CREATED INT, + UPDATED DATETIME, + USER_UPDATED INT +); + +DROP TABLE IF EXISTS PRODUCT; + +CREATE TABLE PRODUCT ( + ID INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + NAME VARCHAR(255), + VENDOR_ID INT, + STATUS INT, + CREATED DATETIME, + USER_CREATED INT, + UPDATED DATETIME, + USER_UPDATED INT +); + +DROP TABLE IF EXISTS PRODUCT_JN; + +CREATE TABLE PRODUCT_JN ( + PRODUCT_ID INT NOT NULL, + USER_ID INT, + OLD_VALUE VARCHAR(255), + NEW_VALUE VARCHAR(255), + CREATED DATETIME +); + +DROP FUNCTION IF EXISTS IS_VENDOR_AUTHORIZED; + +DELIMITER $$ +CREATE FUNCTION IS_VENDOR_AUTHORIZED(USER_ID INT, VENDOR_ID INT) + RETURNS BOOLEAN +BEGIN + DECLARE +IS_AUTH BOOLEAN; +SELECT TRUE +INTO IS_AUTH +FROM VENDOR v +WHERE ID = VENDOR_ID + AND ACCOUNT_ID + AND EXISTS(SELECT 1 FROM USER u WHERE u.ID = USER_ID AND u.ACCOUNT_ID = v.ACCOUNT_ID); +RETURN IS_AUTH; +END $$ +DELIMITER; + + +DROP FUNCTION IF EXISTS IS_PRODUCT_AUTHORIZED; + +DELIMITER $$ +CREATE FUNCTION IS_PRODUCT_AUTHORIZED(USER_ID INT, PID INT) + RETURNS BOOLEAN +BEGIN + DECLARE +IS_AUTH BOOLEAN; + SET +IS_AUTH = FALSE ; +SELECT TRUE +INTO IS_AUTH +FROM VENDOR v + JOIN PRODUCT p ON v.ID = p.VENDOR_ID +WHERE p.ID = PID + AND ACCOUNT_ID + AND EXISTS(SELECT 1 + FROM USER u + WHERE u.ID = USER_ID + AND u.ACCOUNT_ID = v.ACCOUNT_ID); +RETURN IS_AUTH; +END $$ +DELIMITER; + + +DROP TABLE IF EXISTS DISTRICT; +CREATE TABLE DISTRICT ( + ID INT PRIMARY KEY, + NAME VARCHAR(255) +); + +DROP TABLE IF EXISTS CITY; +CREATE TABLE CITY ( + ID INT PRIMARY KEY, + NAME varchar(255), + ZIP_CODE varchar(255), + DISTRICT_ID INT +); + +DROP TABLE IF EXISTS TEAM; +CREATE TABLE TEAM ( + ID INT PRIMARY KEY, + NAME varchar(255), + ACTIVE INTEGER +); + +DROP TABLE IF EXISTS USER_TEAM; +CREATE TABLE USER_TEAM ( + ID INT PRIMARY KEY, + USER_ID INT, + TEAM_ID INT +); + +DROP TABLE IF EXISTS EVENTS; +CREATE TABLE EVENTS ( + ID INT AUTO_INCREMENT PRIMARY KEY, + NAME varchar(255), + QUANTITY INT +); + +DROP TABLE IF EXISTS EVENTS_PERFORMANCE; +CREATE TABLE EVENTS_PERFORMANCE +( + ID INT AUTO_INCREMENT PRIMARY KEY, + PRICE INT, + EVENT_ID INT, + TIMESTAMP DATE, + FOREIGN KEY (EVENT_ID) REFERENCES EVENTS (ID) +); + +DROP TABLE IF EXISTS FOOS; +CREATE TABLE FOOS ( + ID INT AUTO_INCREMENT PRIMARY KEY, + NAME varchar(255), + QUANTITY INT +); + +DROP TABLE IF EXISTS FOOS_CHANGES; +CREATE TABLE FOOS_CHANGES ( + ID INT AUTO_INCREMENT PRIMARY KEY, + PREVIOUS TEXT +); + +DROP TABLE IF EXISTS FOOS_PERFORMANCE; +CREATE TABLE FOOS_PERFORMANCE ( + ID INT AUTO_INCREMENT PRIMARY KEY, + PERF_NAME varchar(255), + PERF_QUANTITY INT, + FOO_ID INT, + FOREIGN KEY (FOO_ID) REFERENCES FOOS(ID) +); + +DROP TABLE IF EXISTS DIFF_JN; +CREATE TABLE DIFF_JN ( + ID INT AUTO_INCREMENT PRIMARY KEY, + DIFF LONGTEXT +); + +DROP TABLE IF EXISTS USER_METADATA; +CREATE TABLE USER_METADATA ( + ID INT AUTO_INCREMENT PRIMARY KEY, + USER_ID INT, + IS_ENABLED BIT, + IS_ACTIVATED BIT, + FOREIGN KEY (USER_ID) REFERENCES USER (ID) +); + +DROP TABLE IF EXISTS OBJECTS; +CREATE TABLE OBJECTS ( + ID INT AUTO_INCREMENT PRIMARY KEY, + OBJECT TEXT, + CLASS_NAME VARCHAR(255) +); + +DROP TABLE IF EXISTS BAR; +CREATE TABLE BAR ( + ID INT AUTO_INCREMENT PRIMARY KEY, + NAME varchar(255), + PRICE DOUBLE PRECISION, + TAX FLOAT +); + +DROP TABLE IF EXISTS DATLY_JOBS; + +CREATE TABLE `DATLY_JOBS` ( + `MatchKey` varchar(3000) NOT NULL, + `Status` varchar(40) NOT NULL, + `Metrics` text NOT NULL, + `Connector` varchar(256), + `TableName` varchar(256), + `TableDataset` varchar(256), + `TableSchema` varchar(256), + `CreateDisposition` varchar(256), + `Template` varchar(256), + `WriteDisposition` varchar(256), + `Cache` text, + `CacheKey` varchar(256), + `CacheSet` varchar(256), + `CacheNamespace` varchar(256), + `Method` varchar(256) NOT NULL, + `URI` varchar(256) NOT NULL, + `State` text NOT NULL, + `UserEmail` varchar(256), + `UserID` varchar(256), + `MainView` varchar(256) NOT NULL, + `Module` varchar(256) NOT NULL, + `Labels` varchar(256) NOT NULL, + `JobType` varchar(256) NOT NULL, + `EventURL` varchar(256) NOT NULL, + `Error` text, + `CreationTime` datetime NOT NULL, + `StartTime` datetime DEFAULT NULL, + `ExpiryTime` datetime DEFAULT NULL, + `EndTime` datetime DEFAULT NULL, + `WaitTimeInMcs` int(11) NOT NULL, + `RuntimeInMcs` int(11) NOT NULL, + `SQLQuery` text NOT NULL, + `Deactivated` tinyint(1), + `ID` varchar(40) NOT NULL, + PRIMARY KEY (`ID`) +); + +CREATE INDEX DATLY_JOBS_REF ON DATLY_JOBS(MatchKey, CreationTime, Deactivated); diff --git a/e2e/v1/dql/dev/district/district_pagination.sql b/e2e/v1/dql/dev/district/district_pagination.sql new file mode 100644 index 000000000..c927d4883 --- /dev/null +++ b/e2e/v1/dql/dev/district/district_pagination.sql @@ -0,0 +1,14 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/district/pagination') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/meta/districts', 'GET')) + +#define($_ = $IDs<[]int>(query/IDs)) +#define($_ = $Page(query/page).Optional().QuerySelector('districts')) +#define($_ = $Data(output/view).Embed()) + + +SELECT districts.*, + cities.*, + set_limit(cities, 2) +FROM (SELECT t.* FROM DISTRICT t WHERE 1 = 1 AND ID IN ($IDs)) districts +JOIN (SELECT * FROM CITY t) cities ON districts.ID = cities.DISTRICT_ID diff --git a/e2e/v1/dql/dev/events/basic_one_one.dql b/e2e/v1/dql/dev/events/basic_one_one.dql new file mode 100644 index 000000000..061949707 --- /dev/null +++ b/e2e/v1/dql/dev/events/basic_one_one.dql @@ -0,0 +1,12 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/events/relation_one_one') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/basic/events-one-one', 'GET')) + + +#set( $_ = $Data(output/view).Embed()) + + +SELECT EVENTS.*, + EVENTS_PERFORMANCE.* +FROM (SELECT ID, QUANTITY FROM EVENTS) EVENTS +JOIN (SELECT * FROM EVENTS_PERFORMANCE) EVENTS_PERFORMANCE ON EVENTS.ID = EVENTS_PERFORMANCE.EVENT_ID AND 1=1 diff --git a/e2e/v1/dql/dev/events/patch_basic_one.dql b/e2e/v1/dql/dev/events/patch_basic_one.dql new file mode 100644 index 000000000..ef5f2eb14 --- /dev/null +++ b/e2e/v1/dql/dev/events/patch_basic_one.dql @@ -0,0 +1,12 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/basic/foos', 'PATCH')) +#setting($_ = $useTemplate('patch')) + + +#set($_ = $Foos(body/).Cardinality('One').Tag('anonymous:"true"')) +#set($_ = $Foos(body/).Output().Tag('anonymous:"true"')) + + +SELECT foos.* +FROM (SELECT * FROM FOOS) foos diff --git a/e2e/v1/dql/dev/events/post_basic_many.dql b/e2e/v1/dql/dev/events/post_basic_many.dql new file mode 100644 index 000000000..04a7e59c4 --- /dev/null +++ b/e2e/v1/dql/dev/events/post_basic_many.dql @@ -0,0 +1,10 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/events/basic_many') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/basic/events-many', 'POST')) + +#set($_ = $Events(body/).Cardinality('Many').Tag('anonymous:"true"')) +#set($_ = $Events(body/).Output().Tag('anonymous:"true"')) + + +SELECT events.* +FROM (SELECT * FROM EVENTS) events diff --git a/e2e/v1/dql/dev/events/post_basic_one.dql b/e2e/v1/dql/dev/events/post_basic_one.dql new file mode 100644 index 000000000..fc96181fd --- /dev/null +++ b/e2e/v1/dql/dev/events/post_basic_one.dql @@ -0,0 +1,11 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/events/basic_one') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/basic/events', 'POST')) + + +#set($_ = $Events(body/).Cardinality('One').Tag('anonymous:"true"')) +#set($_ = $Events(body/).Output().Tag('anonymous:"true"')) + + +SELECT events.* +FROM (SELECT * FROM EVENTS) events diff --git a/e2e/v1/dql/dev/events/post_comprehensive_many.dql b/e2e/v1/dql/dev/events/post_comprehensive_many.dql new file mode 100644 index 000000000..ffe8866c0 --- /dev/null +++ b/e2e/v1/dql/dev/events/post_comprehensive_many.dql @@ -0,0 +1,12 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/events/comprehensive_many') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/comprehensive/events-many', 'POST')) + + +#set($_ = $Events(body/Data).Cardinality('Many')) +#set($_ = $Status(output/status).Tag('anonymous:"true"')) +#set($_ = $Data(body/Data).Output()) + + +SELECT events.* +FROM (SELECT * FROM EVENTS) events diff --git a/e2e/v1/dql/dev/events/post_except.dql b/e2e/v1/dql/dev/events/post_except.dql new file mode 100644 index 000000000..b8763899c --- /dev/null +++ b/e2e/v1/dql/dev/events/post_except.dql @@ -0,0 +1,11 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/events/except') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/basic/events-except', 'POST')) + +#set($_ = $Events(body/).Cardinality('One').Tag('anonymous:"true"')) +#set($_ = $Events(body/).Output().Tag('anonymous:"true"')) + + + +SELECT events.* EXCEPT NAME +FROM (SELECT * FROM EVENTS) events diff --git a/e2e/v1/dql/dev/team/team.dql b/e2e/v1/dql/dev/team/team.dql new file mode 100644 index 000000000..6c1320af2 --- /dev/null +++ b/e2e/v1/dql/dev/team/team.dql @@ -0,0 +1,5 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/team/delete') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/team/{teamID}', 'DELETE')) + +DELETE FROM TEAM WHERE ID = ${teamID} diff --git a/e2e/v1/dql/dev/team/user_team.dql b/e2e/v1/dql/dev/team/user_team.dql new file mode 100644 index 000000000..0f9f5480c --- /dev/null +++ b/e2e/v1/dql/dev/team/user_team.dql @@ -0,0 +1,40 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/team/user_team') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/teams', 'PUT')) + +#define($_ = $TeamIDs<[]int>(query/TeamIDs)) + +#set($teamStatsIndex = $Unsafe.TeamStats.IndexBy("Id")) + +#define($_ = $TeamStats(view/team_stats). + WithColumnType('ID', 'int'). + WithColumnType('TEAM_MEMBERS', 'int'). + WithColumnType('NAME', 'string') /* + SELECT + t.ID, + ( + CASE + WHEN ut.TEAM_ID IS NULL THEN 0 + ELSE COUNT(1) + END + ) as TEAM_MEMBERS, + t.NAME as NAME + FROM TEAM t + LEFT JOIN USER_TEAM ut ON t.ID = ut.TEAM_ID + WHERE t.ID IN ($TeamIDs) + GROUP BY t.ID +*/) + +#foreach($teamID in $Unsafe.TeamIDs) + #if($teamStatsIndex.HasKey($teamID) == false) + $logger.FatalfWithCode(400, "not found team with ID %v", $teamID) + #end + + #set($aTeam = $teamStatsIndex[$teamID]) + #if($aTeam.TeamMembers != 0) + $logger.FatalfWithCode(400, "can't deactivate team %v with %v members", $aTeam.Name, $aTeam.TeamMembers) + #end +UPDATE TEAM +SET ACTIVE = false +WHERE ID = $teamID; +#end diff --git a/e2e/v1/dql/dev/user/user_metadata.dql b/e2e/v1/dql/dev/user/user_metadata.dql new file mode 100644 index 000000000..0f6c71def --- /dev/null +++ b/e2e/v1/dql/dev/user/user_metadata.dql @@ -0,0 +1,10 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/user/mysql_boolean') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/user-metadata', 'GET')) + +#define($_ = $Fields<[]string>(query/fields).Optional().QuerySelector('user_metadata')) +#define($_ = $Page(query/page).Optional().QuerySelector('user_metadata')) +#define($_ = $UserMetadata(output/view).Embed()) + +SELECT user_metadata.* +FROM (SELECT * FROM USER_METADATA t) user_metadata diff --git a/e2e/v1/dql/dev/user/user_tree.sql b/e2e/v1/dql/dev/user/user_tree.sql new file mode 100644 index 000000000..33b8e13c1 --- /dev/null +++ b/e2e/v1/dql/dev/user/user_tree.sql @@ -0,0 +1,10 @@ +#package('github.com/viant/datly/e2e/v1/shape/dev/user/tree') +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/users/', 'GET')) + +#define($_ = $Data(output/view)) +#define($_ = $Status(output/status).Embed()) + +SELECT user.* EXCEPT MGR_ID, + self_ref(user, 'Team', 'ID', 'MGR_ID') +FROM (SELECT t.* FROM USER t ) user diff --git a/e2e/v1/go_bootstrap.yaml b/e2e/v1/go_bootstrap.yaml new file mode 100644 index 000000000..aa91fcf2d --- /dev/null +++ b/e2e/v1/go_bootstrap.yaml @@ -0,0 +1,227 @@ +init: + wildcardConfig: ${v1Path}/autogen/Datly/config_go_all.json + singleConfig: ${v1Path}/autogen/Datly/config_go_vendor_list.json + routerAppDir: ${v1Path}/routerapp + routerPkgDir: ${v1Path}/routerpkg + routerImports: ${routerAppDir}/imports_gen.go + goBootstrapPort: 8081 + wildcardExpect: $LoadData('${v1Path}/cases/001_relation_one_to_many/expect.json') + adminExpect: $LoadData('${v1Path}/cases/019_component_dependency/expect_admin.json') + +pipeline: + prepare: + action: exec:run + target: $target + checkError: true + commands: + - cd ${appPath} + - mkdir -p ${routerAppDir} + - mkdir -p ${v1Path}/autogen/Datly + - printf "package main\n\nimport (\n" > ${routerImports} + - | + find ${v1Path}/shape ${routerPkgDir} -type f -name '*.go' ! -name '*_test.go' -print \ + | xargs -n1 dirname \ + | sort -u \ + | sed "s#^${appPath}/#github.com/viant/datly/#" \ + | sed 's#^#\t_ "#; s#$#"#' >> ${routerImports} + - printf ")\n" >> ${routerImports} + - | + cat > ${wildcardConfig} <<'EOF' + { + "APIKeys": [ + { + "Header": "App-Secret-Id", + "URI": "/v1/api/shape/dev/secured", + "Value": "changeme" + } + ], + "APIPrefix": "/v1/api/shape", + "GoBootstrap": { + "Packages": [ + "github.com/viant/datly/e2e/v1/shape/dev/...", + "github.com/viant/datly/e2e/v1/routerpkg/..." + ] + }, + "DependencyURL": "${v1Path}/autogen/Datly/dependencies", + "Endpoint": { + "Port": ${goBootstrapPort} + }, + "JWTValidator": { + "RSA": [ + { + "Key": "blowfish://default", + "URL": "file://localhost${appPath}/e2e/local/jwt/public.enc" + } + ] + }, + "JwtSigner": { + "RSA": { + "URL": "file://localhost${appPath}/e2e/local/jwt/public.enc", + "Key": "blowfish://default" + } + }, + "Meta": { + "CacheURI": "/v1/api/cache/warmup", + "ConfigURI": "/v1/api/meta/config", + "MetricURI": "/v1/api/meta/metric", + "OpenApiURI": "/v1/api/meta/openapi", + "StateURI": "/v1/api/meta/state", + "StatusURI": "/v1/api/status", + "StructURI": "/v1/api/meta/struct", + "ViewURI": "/v1/api/meta/view" + }, + "SyncFrequencyMs": 2000 + } + EOF + - | + cat > ${singleConfig} <<'EOF' + { + "APIKeys": [ + { + "Header": "App-Secret-Id", + "URI": "/v1/api/shape/dev/secured", + "Value": "changeme" + } + ], + "APIPrefix": "/v1/api/shape", + "GoBootstrap": { + "Packages": [ + "github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/list" + ] + }, + "DependencyURL": "${v1Path}/autogen/Datly/dependencies", + "Endpoint": { + "Port": ${goBootstrapPort} + }, + "JWTValidator": { + "RSA": [ + { + "Key": "blowfish://default", + "URL": "file://localhost${appPath}/e2e/local/jwt/public.enc" + } + ] + }, + "JwtSigner": { + "RSA": { + "URL": "file://localhost${appPath}/e2e/local/jwt/public.enc", + "Key": "blowfish://default" + } + }, + "Meta": { + "CacheURI": "/v1/api/cache/warmup", + "ConfigURI": "/v1/api/meta/config", + "MetricURI": "/v1/api/meta/metric", + "OpenApiURI": "/v1/api/meta/openapi", + "StateURI": "/v1/api/meta/state", + "StatusURI": "/v1/api/status", + "StructURI": "/v1/api/meta/struct", + "ViewURI": "/v1/api/meta/view" + }, + "SyncFrequencyMs": 2000 + } + EOF + + build: + action: exec:run + target: $target + checkError: true + commands: + - cd ${appPath} + - export GO111MODULE=on + - export GOFLAGS=-mod=mod + - go build -ldflags "-X main.BuildTimeInS=`date +%s`" -o /tmp/datly_go_router ./e2e/v1/routerapp + + wildcardApp: + stop: + action: process:stop + target: $target + input: datly_go_router + + start: + action: process:start + sleepTimeMs: 6000 + target: $target + directory: /tmp/ + checkError: true + immuneToHangups: true + env: + TEST: 1 + command: ulimit -Sn 10000 && ./datly_go_router -c=${wildcardConfig} > /tmp/datly_v1_go_all.out 2>&1 + + wildcardJWT: + action: secret:signJWT + privateKey: + URL: ${appPath}/e2e/cloud/jwt/private.enc + Key: blowfish://default + claims: + userID: 2 + firstName: Developer + email: dev@viantint.com + + wildcardTest: + action: http/runner:send + requests: + - Method: GET + description: wildcard GoBootstrap loads vendor list route from generated router packages + URL: http://127.0.0.1:${goBootstrapPort}/v1/api/shape/dev/vendors/ + Expect: + Code: 200 + JSONBody: $wildcardExpect + - Method: GET + description: wildcard GoBootstrap loads cross-package component dependency route + URL: http://127.0.0.1:${goBootstrapPort}/v1/api/shape/dev/vendors/component-acl?name=2 + Header: + Authorization: Bearer ${wildcardJWT.TokenString} + Expect: + Code: 200 + JSONBody: $adminExpect + - Method: GET + description: wildcard GoBootstrap loads linked handler route from linked Go package + URL: http://127.0.0.1:${goBootstrapPort}/v1/api/shape/dev/linked/auth?echo=hello + Header: + Authorization: Bearer ${wildcardJWT.TokenString} + Expect: + Code: 200 + JSONBody: + status: ok + data: + userID: 2 + firstName: Developer + echo: hello + + singleApp: + stop: + action: process:stop + target: $target + input: datly_go_router + + start: + action: process:start + sleepTimeMs: 6000 + target: $target + directory: /tmp/ + checkError: true + immuneToHangups: true + env: + TEST: 1 + command: ulimit -Sn 10000 && ./datly_go_router -c=${singleConfig} > /tmp/datly_v1_go_one.out 2>&1 + + singleTest: + action: http/runner:send + requests: + - Method: GET + description: single-package GoBootstrap loads only the requested vendor/list package + URL: http://127.0.0.1:${goBootstrapPort}/v1/api/shape/dev/vendors/ + Expect: + Code: 200 + JSONBody: $wildcardExpect + - Method: GET + description: single-package GoBootstrap does not load unrelated user ACL route + URL: http://127.0.0.1:${goBootstrapPort}/v1/api/shape/dev/auth/user-acl + Expect: + Code: 404 + + cleanup: + action: process:stop + target: $target + input: datly_go_router diff --git a/e2e/v1/regression/app.yaml b/e2e/v1/regression/app.yaml new file mode 100644 index 000000000..50f8c3824 --- /dev/null +++ b/e2e/v1/regression/app.yaml @@ -0,0 +1,17 @@ +pipeline: + datly: + stop: + action: process:stop + target: $target + input: datly + + start: + action: process:start + sleepTimeMs: 6000 + target: $target + directory: /tmp/ + checkError: true + immuneToHangups: true + env: + TEST: 1 + command: pkill -f '${v1Path}/autogen/Datly/config.json -z=/tmp/jobs/datly_v1' >/dev/null 2>&1 || true; ulimit -Sn 10000 && ./datly -c=${v1Path}/autogen/Datly/config.json -z=/tmp/jobs/datly_v1 --mcpPort=8281 > /tmp/datly_v1.out 2>&1 diff --git a/e2e/v1/regression/db.yaml b/e2e/v1/regression/db.yaml new file mode 100644 index 000000000..01c96ed33 --- /dev/null +++ b/e2e/v1/regression/db.yaml @@ -0,0 +1,9 @@ +pipeline: + register: + action: dsunit:register + datastore: dev + recreate: false + config: + driverName: mysql + descriptor: '[username]:[password]@tcp(${dbIP.mysql}:3306)/[dbname]?parseTime=true' + credentials: $mysqlCredentials diff --git a/e2e/v1/regression/regression.yaml b/e2e/v1/regression/regression.yaml new file mode 100644 index 000000000..56c85ba6d --- /dev/null +++ b/e2e/v1/regression/regression.yaml @@ -0,0 +1,36 @@ +init: + +pipeline: + database: + action: run + request: '@db.yaml' + + app: + when: $debugger!=on + description: start datly app (DQLBootstrap loads DQL from dql/ folder) + action: run + request: '@app' + + test: + tag: $pathMatch + data: + '[]dev_dbsetup': '@dbsetup/dev' + + subPath: '../cases/${index}_*' + range: 1..022 + template: + checkSkip: + action: nop + comments: use case init + skip: $HasResource(${path}/skip.txt) + + dbsetup: + when: $Len($dev_dbsetup) > 0 + action: 'dsunit:prepare' + datastore: dev + expand: true + data: $dev_dbsetup + + test: + action: run + request: '@test' diff --git a/e2e/v1/routerapp/main.go b/e2e/v1/routerapp/main.go new file mode 100644 index 000000000..5e4915f77 --- /dev/null +++ b/e2e/v1/routerapp/main.go @@ -0,0 +1,43 @@ +package main + +import ( + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/viant/afs/embed" + _ "github.com/viant/afsc/gs" + _ "github.com/viant/bigquery" + "github.com/viant/datly/cmd" + "github.com/viant/datly/cmd/env" + "github.com/viant/datly/service/executor/expand" + _ "github.com/viant/scy/kms/blowfish" + _ "github.com/viant/sqlx/metadata/product/mysql" + _ "github.com/viant/sqlx/metadata/product/pg" + _ "github.com/viant/sqlx/metadata/product/sqlite" + "os" + "strconv" + "time" +) + +var ( + Version = "development" + BuildTimeInS string +) + +func init() { + os.Setenv("DATLY_NOPANIC", "true") + expand.SetPanicOnError(false) + + if BuildTimeInS != "" { + seconds, err := strconv.Atoi(BuildTimeInS) + if err != nil { + panic(err) + } + env.BuildTime = time.Unix(int64(seconds), 0) + } +} + +func main() { + if err := cmd.RunApp(Version, os.Args[1:]); err != nil { + panic(err) + } +} diff --git a/e2e/v1/routerpkg/dev/linkedauth/handler.go b/e2e/v1/routerpkg/dev/linkedauth/handler.go new file mode 100644 index 000000000..88e80ed3d --- /dev/null +++ b/e2e/v1/routerpkg/dev/linkedauth/handler.go @@ -0,0 +1,61 @@ +package linkedauth + +import ( + "context" + "reflect" + + "github.com/viant/scy/auth/jwt" + "github.com/viant/xdatly" + xhandler "github.com/viant/xdatly/handler" + "github.com/viant/xdatly/handler/response" + "github.com/viant/xdatly/types/core" + "github.com/viant/xdatly/types/custom/dependency/checksum" +) + +const packageName = "github.com/viant/datly/e2e/v1/routerpkg/dev/linkedauth" + +func init() { + core.RegisterType(packageName, "LinkedAuthInput", reflect.TypeOf(LinkedAuthInput{}), checksum.GeneratedTime) + core.RegisterType(packageName, "LinkedAuthOutput", reflect.TypeOf(LinkedAuthOutput{}), checksum.GeneratedTime) + core.RegisterType(packageName, "LinkedAuthPayload", reflect.TypeOf(LinkedAuthPayload{}), checksum.GeneratedTime) + core.RegisterType(packageName, "Handler", reflect.TypeOf(Handler{}), checksum.GeneratedTime) +} + +type LinkedAuthInput struct { + Jwt *jwt.Claims `parameter:",kind=header,in=Authorization,dataType=string,errorCode=401" codec:"JwtClaim"` + Echo string `parameter:",kind=query,in=echo"` +} + +type LinkedAuthPayload struct { + UserID int `json:"userID"` + FirstName string `json:"firstName,omitempty"` + Echo string `json:"echo,omitempty"` +} + +type LinkedAuthOutput struct { + response.Status `parameter:",kind=output,in=status" json:",omitempty"` + Data *LinkedAuthPayload `parameter:",kind=output,in=view" json:"data,omitempty"` +} + +type Handler struct{} + +func (h *Handler) Exec(ctx context.Context, sess xhandler.Session) (interface{}, error) { + input := &LinkedAuthInput{} + if err := sess.Stater().Bind(ctx, input); err != nil { + return nil, err + } + if input.Jwt == nil { + return nil, response.NewError(401, "unauthorized access") + } + return &LinkedAuthOutput{ + Data: &LinkedAuthPayload{ + UserID: input.Jwt.UserID, + FirstName: input.Jwt.FirstName, + Echo: input.Echo, + }, + }, nil +} + +type LinkedAuthRouter struct { + LinkedAuth xdatly.Component[LinkedAuthInput, LinkedAuthOutput] `component:",path=/v1/api/shape/dev/linked/auth,method=GET,connector=dev,input=LinkedAuthInput,output=LinkedAuthOutput,handler=Handler"` +} diff --git a/e2e/v1/run.yaml b/e2e/v1/run.yaml new file mode 100644 index 000000000..e306b39d4 --- /dev/null +++ b/e2e/v1/run.yaml @@ -0,0 +1,48 @@ +init: + yesterday: $FormatTime('yesterdayInUTC', 'yyyy-MM-dd HH:mm:ss') + today: $FormatTime('nowInUTC', 'yyyy-MM-dd HH:mm:ss') + debugger: '$params.debugger?$params.debugger:0' + + target: + URL: ssh://localhost/ + credentials: localhost + appPath: $WorkingDirectory(../..) + v1Path: ${appPath}/e2e/v1 + mysqlCredentials: mysql-e2e + dbIP: + mysql: localhost + qMark: '?' + connectors: --connector 'dev|mysql|root:dev@tcp(${dbIP.mysql}:3306)/dev${qMark}parseTime=true' + +pipeline: + init: + description: initialise test (docker, database) + system: + action: run + request: '@system' + tasks: '*' + + datastore: + action: run + request: '@datastore' + tasks: '*' + + shapes: + description: generate Go types and route YAML from DQL + action: run + request: '@shapes' + + build: + action: run + request: '@build' + tasks: 'deploy' + + test: + action: run + description: run v1 regression test + request: '@regression/regression' + + goBootstrap: + action: run + description: run linked Go router bootstrap regression test + request: '@go_bootstrap' diff --git a/e2e/v1/shapes.yaml b/e2e/v1/shapes.yaml new file mode 100644 index 000000000..e50ed79c6 --- /dev/null +++ b/e2e/v1/shapes.yaml @@ -0,0 +1,54 @@ +init: + shapePath: ${v1Path}/shape + shapeDevPath: ${shapePath}/dev + repoPath: ${v1Path}/autogen + routerImports: ${v1Path}/routerapp/imports_gen.go + conn: -c='dev|mysql|root:dev@tcp(${dbIP.mysql}:3306)/dev${qMark}parseTime=true' + jwt: -J='${appPath}/e2e/local/jwt/public.enc|blowfish://default' + api: -a='/v1/api/shape' + mode: --skip-yaml + +pipeline: + + cleanup: + action: exec:run + description: clean up generated shapes and routes + target: '$target' + checkError: true + commands: + - mkdir -p ${repoPath} + - rm -rf ${repoPath} + - mkdir -p ${shapePath} + - rm -rf ${shapeDevPath} + - rm -rf ${shapePath} + - rm -f ${routerImports} + + vendor: + action: exec:run + TimeoutMs: 120000 + checkError: true + commands: + - cd ${v1Path} + - /tmp/datly transcribe -u dev/vendor -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vendor_list.dql + - /tmp/datly transcribe -u dev/vendor/details -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vendor_details.dql + - /tmp/datly transcribe -u dev/vendor/col -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vendor_col_in.dql + - /tmp/datly transcribe -u dev/vendor/auth -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vendor_auth.dql + - /tmp/datly transcribe -u dev/vendor/secured -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vendor_apikey.dql + - /tmp/datly transcribe -u dev/vendor/update -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/product_update.dql + - /tmp/datly transcribe -u dev/vendor/meta -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vendor_meta.dql + - /tmp/datly transcribe -u dev/vendor/meta-nested -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/child_meta.dql + - /tmp/datly transcribe -u dev/vendor/meta-format -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/meta_format.dql + - /tmp/datly transcribe -u dev/vendor/header -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/header_vendors.dql + - /tmp/datly transcribe -u dev/vendor/env -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/const.dql + - /tmp/datly transcribe -u dev/vendor/variables -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vars.dql + - /tmp/datly transcribe -u dev/vendor/user_acl -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/user_acl.dql + - /tmp/datly transcribe -u dev/vendor/component_acl -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/component_acl/vendor_acl.dql + - /tmp/datly transcribe -u dev/vendor/grouping -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/vendorsrv/vendors_grouping.dql + - /tmp/datly transcribe -u dev/district/meta -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/district/district_pagination.sql + - /tmp/datly transcribe -u dev/user/metadata -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/user/user_metadata.dql + - /tmp/datly transcribe -u dev/user -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/user/user_tree.sql + - /tmp/datly transcribe -u dev/basic/events-one-one -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/events/basic_one_one.dql + - /tmp/datly transcribe -u dev/basic/foos -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/events/patch_basic_one.dql + - /tmp/datly transcribe -u dev/basic/foos-many -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/events/patch_basic_many.dql + - /tmp/datly transcribe -u dev/team/user_team -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/team/user_team.dql + - /tmp/datly transcribe -u dev/team/delete -m ${shapePath} -r ${repoPath} $conn $jwt $api $mode -s dql/dev/team/team.dql diff --git a/e2e/v1/system.yaml b/e2e/v1/system.yaml new file mode 100644 index 000000000..ad0ee43eb --- /dev/null +++ b/e2e/v1/system.yaml @@ -0,0 +1,32 @@ +init: + mysqlSecrets: ${secrets.$mysqlCredentials} +pipeline: + + stop: + services: + action: docker:stop + images: + - mysql + - aerospike + + start: + services: + mysql_dev: + action: docker:run + image: mysql:5.7 + platform: linux/amd64 + name: mysql_dev + ports: + 3306: 3306 + env: + MYSQL_ROOT_PASSWORD: ${mysqlSecrets.Password} + + aerospike: + action: docker:run + platform: linux/amd64 + image: 'aerospike:ce-6.2.0.2' + name: aero + ports: + 3000: 3000 + 3001: 3001 + 3002: 3002 diff --git a/gateway/config.go b/gateway/config.go index 9aedccb0e..58f8d2726 100644 --- a/gateway/config.go +++ b/gateway/config.go @@ -29,6 +29,7 @@ type ( ExposableConfig struct { APIPrefix string //like /v1/api/ RouteURL string + GoBootstrap *GoBootstrap DQLBootstrap *DQLBootstrap ContentURL string PluginsURL string @@ -77,6 +78,11 @@ type ( DQLPathMarker string RoutesRelativePath string } + + GoBootstrap struct { + Packages []string + Exclude []string + } ) const ( @@ -101,7 +107,7 @@ func (c *Config) Validate() error { if c.DQLBootstrap != nil && len(c.DQLBootstrap.Sources) == 0 { return fmt.Errorf("DQLBootstrap.Sources was empty") } - if c.RouteURL == "" && !c.hasDQLBootstrap() { + if c.RouteURL == "" && !c.hasDQLBootstrap() && !c.hasGoBootstrap() { return fmt.Errorf("RouteURL was empty") } return nil @@ -111,6 +117,10 @@ func (c *Config) hasDQLBootstrap() bool { return c != nil && c.DQLBootstrap != nil && len(c.DQLBootstrap.Sources) > 0 } +func (c *Config) hasGoBootstrap() bool { + return c != nil && c.GoBootstrap != nil && len(c.GoBootstrap.Packages) > 0 +} + func (d *DQLBootstrap) ShouldFailFast() bool { if d == nil || d.FailFast == nil { return true diff --git a/gateway/dql_bootstrap.go b/gateway/dql_bootstrap.go index fe7d62dbc..fc5239f5d 100644 --- a/gateway/dql_bootstrap.go +++ b/gateway/dql_bootstrap.go @@ -2,11 +2,11 @@ package gateway import ( "context" - "encoding/json" "fmt" "os" "path" "path/filepath" + "reflect" "sort" "strings" @@ -17,6 +17,8 @@ import ( shapeLoad "github.com/viant/datly/repository/shape/load" datlyservice "github.com/viant/datly/service" "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" ) func (r *Service) applyDQLBootstrap(ctx context.Context, repo *repository.Service, cfg *DQLBootstrap) error { @@ -72,7 +74,7 @@ func (r *Service) applyDQLBootstrap(ctx context.Context, repo *repository.Servic return nil } -func compileBootstrapComponent(ctx context.Context, compiler *shapeCompile.DQLCompiler, loader *shapeLoad.Loader, repo *repository.Service, sourcePath string, cfg *DQLBootstrap, apiPrefix string) (*repository.Component, error) { +func compileBootstrapComponent(ctx context.Context, compiler *shapeCompile.DQLCompiler, loader *shapeLoad.Loader, repo *repository.Service, sourcePath string, cfg *DQLBootstrap, _ string) (*repository.Component, error) { data, err := os.ReadFile(sourcePath) if err != nil { return nil, fmt.Errorf("failed to read DQL bootstrap source %s: %w", sourcePath, err) @@ -101,22 +103,63 @@ func compileBootstrapComponent(ctx context.Context, compiler *shapeCompile.DQLCo if !ok || loaded == nil { return nil, fmt.Errorf("unexpected shape component artifact for %s", sourcePath) } + bootstrapMetadata := snapshotBootstrapViewMetadata(componentArtifact.Resource) rootView := lookupRootView(componentArtifact.Resource, loaded.RootView) if rootView == nil { return nil, fmt.Errorf("missing root view %q for %s", loaded.RootView, sourcePath) } - method, uri := resolvePathSettings(sourcePath, dql, apiPrefix) + method := strings.TrimSpace(strings.ToUpper(loaded.Method)) + uri := strings.TrimSpace(loaded.URI) + if method == "" && len(loaded.ComponentRoutes) > 0 && loaded.ComponentRoutes[0] != nil { + method = strings.TrimSpace(strings.ToUpper(loaded.ComponentRoutes[0].Method)) + } + if uri == "" && len(loaded.ComponentRoutes) > 0 && loaded.ComponentRoutes[0] != nil { + uri = strings.TrimSpace(loaded.ComponentRoutes[0].RoutePath) + } + if method == "" { + method = "GET" + } + if uri == "" { + return nil, fmt.Errorf("missing shape component route for %s", sourcePath) + } + var outputType reflect.Type + if shouldMaterializeBootstrapOutputType(loaded, rootView) { + pkgPath := bootstrapTypePackage(loaded) + lookupType := componentArtifact.Resource.LookupType() + outputType, err = loaded.OutputReflectType(pkgPath, lookupType) + if err != nil { + return nil, fmt.Errorf("failed to materialize bootstrap output type for %s: %w", sourcePath, err) + } + } componentModel := &repository.Component{ Path: contract.Path{ Method: method, URI: uri, }, Contract: contract.Contract{ + Input: contract.Input{ + Type: state.Type{ + Parameters: loaded.InputParameters(), + }, + }, + Output: contract.Output{ + CaseFormat: bootstrapOutputCaseFormat(loaded), + Cardinality: bootstrapOutputCardinality(loaded, rootView), + Type: state.Type{ + Parameters: loaded.OutputParameters(), + }, + }, Service: defaultServiceForMethod(method, rootView), }, View: rootView, TypeContext: loaded.TypeContext, } + if outputType != nil { + if componentModel.Contract.Output.Type.Schema == nil { + componentModel.Contract.Output.Type.Schema = state.NewSchema(nil) + } + componentModel.Contract.Output.Type.SetType(outputType) + } loadOptions := []repository.Option{} if repo != nil { loadOptions = append(loadOptions, repository.WithResources(repo.Resources())) @@ -129,15 +172,43 @@ func compileBootstrapComponent(ctx context.Context, compiler *shapeCompile.DQLCo if err != nil { return nil, fmt.Errorf("failed to materialize bootstrap component for %s: %w", sourcePath, err) } + mergeBootstrapViewMetadata(components.Resource, bootstrapMetadata) if err = components.Init(ctx); err != nil { return nil, fmt.Errorf("failed to initialize bootstrap component for %s: %w", sourcePath, err) } if len(components.Components) == 0 || components.Components[0] == nil { return nil, fmt.Errorf("empty initialized bootstrap component for %s", sourcePath) } + mergeBootstrapView(components.Components[0].View, lookupRootView(bootstrapMetadata, loaded.RootView)) return components.Components[0], nil } +func bootstrapTypePackage(component *shapeLoad.Component) string { + if component == nil || component.TypeContext == nil { + return "" + } + if pkgPath := strings.TrimSpace(component.TypeContext.PackagePath); pkgPath != "" { + return pkgPath + } + return strings.TrimSpace(component.TypeContext.DefaultPackage) +} + +func shouldMaterializeBootstrapOutputType(component *shapeLoad.Component, rootView *view.View) bool { + if component == nil || rootView == nil || rootView.Schema == nil || rootView.Schema.Cardinality != state.One { + return false + } + for _, item := range component.Output { + if item == nil || item.In == nil || item.In.Kind != state.KindOutput || item.In.Name != "view" { + continue + } + if !strings.Contains(item.Tag, "anonymous") || item.Schema == nil { + return false + } + return item.Schema.Cardinality == state.One + } + return false +} + func mergeBootstrapSharedResources(target *view.Resource, repo *repository.Service) { if target == nil || repo == nil || repo.Resources() == nil { return @@ -180,7 +251,7 @@ func hasRepositoryProvider(ctx context.Context, repo *repository.Service, path * _, err := repo.Registry().LookupProvider(ctx, path) if err != nil { message := strings.ToLower(strings.TrimSpace(err.Error())) - if strings.Contains(message, "not found") { + if strings.Contains(message, "not found") || strings.Contains(message, "couldn't match uri") { return false, nil } return false, err @@ -410,44 +481,128 @@ func lookupRootView(resource *view.Resource, root string) *view.View { return nil } -type bootstrapRuleSettings struct { - Method string `json:"Method"` - URI string `json:"URI"` +func bootstrapOutputCardinality(component *shapeLoad.Component, rootView *view.View) state.Cardinality { + if component != nil { + if output := component.OutputParameters(); len(output) > 0 { + if parameter := output.LookupByLocation(state.KindOutput, "view"); parameter != nil && parameter.Schema != nil && parameter.Schema.Cardinality != "" { + return parameter.Schema.Cardinality + } + } + } + if rootView != nil && rootView.Schema != nil && rootView.Schema.Cardinality != "" { + return rootView.Schema.Cardinality + } + return "" } -func resolvePathSettings(sourcePath, dql, apiPrefix string) (string, string) { - method := "GET" - uri := "" - settings := parseBootstrapRuleSettings(dql) - if settings != nil { - if candidate := strings.TrimSpace(strings.ToUpper(settings.Method)); candidate != "" { - method = candidate +func bootstrapOutputCaseFormat(component *shapeLoad.Component) text.CaseFormat { + if component != nil && component.Directives != nil { + if value := strings.TrimSpace(component.Directives.CaseFormat); value != "" { + return text.CaseFormat(value) } - uri = strings.TrimSpace(settings.URI) } - if uri == "" { - stem := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) - uri = "/" + strings.Trim(stem, "/") - if prefix := strings.TrimSpace(apiPrefix); prefix != "" { - uri = strings.TrimRight(prefix, "/") + uri + return text.CaseFormatLowerCamel +} + +func mergeBootstrapViewMetadata(target, source *view.Resource) { + if target == nil || source == nil { + return + } + sourceViews := source.Views.Index() + for _, candidate := range target.Views { + if candidate == nil { + continue + } + original, _ := sourceViews.Lookup(candidate.Name) + if original == nil { + continue } + mergeBootstrapView(candidate, original) } - return method, uri } -func parseBootstrapRuleSettings(dql string) *bootstrapRuleSettings { - start := strings.Index(dql, "/*") - end := strings.Index(dql, "*/") - if start == -1 || end == -1 || end <= start+2 { - return nil +func mergeBootstrapView(target, source *view.View) { + if target == nil || source == nil { + return } - raw := strings.TrimSpace(dql[start+2 : end]) - if !strings.HasPrefix(raw, "{") || !strings.HasSuffix(raw, "}") { - return nil + if source.AllowNulls != nil { + value := *source.AllowNulls + target.AllowNulls = &value + } + if source.Groupable { + target.Groupable = true + } + if source.Selector != nil { + target.Selector = source.Selector } - ret := &bootstrapRuleSettings{} - if err := json.Unmarshal([]byte(raw), ret); err != nil { + if len(source.ColumnsConfig) > 0 { + target.ColumnsConfig = map[string]*view.ColumnConfig{} + for key, cfg := range source.ColumnsConfig { + if cfg == nil { + continue + } + cloned := *cfg + if cfg.DataType != nil { + value := *cfg.DataType + cloned.DataType = &value + } + if cfg.Tag != nil { + value := *cfg.Tag + cloned.Tag = &value + } + if cfg.Groupable != nil { + value := *cfg.Groupable + cloned.Groupable = &value + } + target.ColumnsConfig[key] = &cloned + } + } +} + +func snapshotBootstrapViewMetadata(resource *view.Resource) *view.Resource { + if resource == nil { return nil } - return ret + result := &view.Resource{} + for _, item := range resource.Views { + if item == nil { + continue + } + cloned := &view.View{ + Name: item.Name, + Groupable: item.Groupable, + } + cloned.Reference.Ref = item.Ref + if item.AllowNulls != nil { + value := *item.AllowNulls + cloned.AllowNulls = &value + } + if item.Selector != nil { + cloned.Selector = item.Selector + } + if len(item.ColumnsConfig) > 0 { + cloned.ColumnsConfig = map[string]*view.ColumnConfig{} + for key, cfg := range item.ColumnsConfig { + if cfg == nil { + continue + } + copied := *cfg + if cfg.DataType != nil { + value := *cfg.DataType + copied.DataType = &value + } + if cfg.Tag != nil { + value := *cfg.Tag + copied.Tag = &value + } + if cfg.Groupable != nil { + value := *cfg.Groupable + copied.Groupable = &value + } + cloned.ColumnsConfig[key] = &copied + } + } + result.Views = append(result.Views, cloned) + } + return result } diff --git a/gateway/dql_bootstrap_test.go b/gateway/dql_bootstrap_test.go index b36714cd0..87a56eb4f 100644 --- a/gateway/dql_bootstrap_test.go +++ b/gateway/dql_bootstrap_test.go @@ -2,15 +2,27 @@ package gateway import ( "context" + "net/http" "os" "path/filepath" + "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + marshalconfig "github.com/viant/datly/gateway/router/marshal/config" + marshaljson "github.com/viant/datly/gateway/router/marshal/json" "github.com/viant/datly/repository" "github.com/viant/datly/repository/contract" + shape "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + operator2 "github.com/viant/datly/service/operator" + "github.com/viant/datly/service/session" "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/datly/view/state/kind/locator" + "github.com/viant/tagly/format/text" ) func TestConfigValidate_AllowsEmptyRouteURLWithDQLBootstrap(t *testing.T) { @@ -55,14 +67,32 @@ func TestDiscoverDQLBootstrapSources(t *testing.T) { assert.Contains(t, sources, filepath.Join(root, "sql", "nested", "b.sql")) } -func TestResolvePathSettings(t *testing.T) { - method, uri := resolvePathSettings("/tmp/orders/get.dql", `/* {"Method":"POST","URI":"/v1/api/orders"} */ SELECT 1`, "/v1/api") - assert.Equal(t, "POST", method) - assert.Equal(t, "/v1/api/orders", uri) +func TestCompileBootstrapComponent_UsesShapeRouteMetadata(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "test_conn", + Driver: "sqlite3", + DSN: ":memory:", + }, + }, + }) - method, uri = resolvePathSettings("/tmp/orders/get.dql", `SELECT 1`, "/v1/api") - assert.Equal(t, "GET", method) - assert.Equal(t, "/v1/api/get", uri) + root := t.TempDir() + source := filepath.Join(root, "orders.dql") + dql := "#setting($_ = $connector('test_conn'))\n#setting($_ = $route('/v1/api/orders', 'POST'))\nSELECT 1 AS id" + require.NoError(t, os.WriteFile(source, []byte(dql), 0o644)) + + component, err := compileBootstrapComponent(ctx, shapeCompile.New(), shapeLoad.New(), repo, source, &DQLBootstrap{}, "/v1/api") + require.NoError(t, err) + require.NotNil(t, component) + assert.Equal(t, "POST", component.Method) + assert.Equal(t, "/v1/api/orders", component.URI) } func TestDQLBootstrapEffectivePrecedence(t *testing.T) { @@ -92,7 +122,7 @@ func TestApplyDQLBootstrap_Precedence(t *testing.T) { root := t.TempDir() source := filepath.Join(root, "test.dql") - require.NoError(t, os.WriteFile(source, []byte(`/* {"Method":"GET","URI":"/v1/api/test","Connector":"test_conn"} */ SELECT 1 AS id`), 0o644)) + require.NoError(t, os.WriteFile(source, []byte("#setting($_ = $connector('test_conn'))\n#setting($_ = $route('/v1/api/test', 'GET'))\nSELECT 1 AS id"), 0o644)) srv := &Service{Config: &Config{ExposableConfig: ExposableConfig{APIPrefix: "/v1/api"}}} routesWins := &DQLBootstrap{ @@ -120,3 +150,274 @@ func TestApplyDQLBootstrap_Precedence(t *testing.T) { require.NotNil(t, component.View) assert.Equal(t, "test", component.View.Name) } + +func TestCompileBootstrapComponent_PreservesShapeIOAndGroupingMetadata(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "dev", + Driver: "mysql", + DSN: "root:dev@tcp(127.0.0.1:3306)/dev?parseTime=true", + }, + }, + }) + + root := t.TempDir() + source := filepath.Join(root, "vendors_grouping.dql") + dql := ` +#setting($_ = $connector('dev')) +#setting($_ = $route('/v1/api/shape/dev/vendors-grouping', 'GET')) +#define($_ = $VendorIDs<[]int>(query/vendorIDs)) +#define($_ = $Fields<[]string>(query/_fields).Optional().QuerySelector('vendor')) +#define($_ = $OrderBy(query/_orderby).Optional().QuerySelector('vendor')) +#define($_ = $Data(output/view).Embed()) +SELECT vendor.*, + groupable(vendor), + allowed_order_by_columns(vendor, 'accountId:ACCOUNT_ID,userCreated:USER_CREATED,totalId:TOTAL_ID,maxId:MAX_ID') +FROM ( + SELECT ACCOUNT_ID, + USER_CREATED, + SUM(ID) AS TOTAL_ID, + MAX(ID) AS MAX_ID + FROM VENDOR t + WHERE t.ID IN ($VendorIDs) + GROUP BY 1, 2 +) vendor` + require.NoError(t, os.WriteFile(source, []byte(dql), 0o644)) + + planResult, err := shapeCompile.New().Compile(ctx, &shape.Source{ + Name: "vendors_grouping", + Path: source, + DQL: dql, + }) + require.NoError(t, err) + artifact, err := shapeLoad.New().LoadComponent(ctx, planResult) + require.NoError(t, err) + loaded, ok := artifact.Component.(*shapeLoad.Component) + require.True(t, ok) + sourceRoot := lookupRootView(artifact.Resource, loaded.RootView) + require.NotNil(t, sourceRoot) + require.NotNil(t, sourceRoot.ColumnsConfig["ACCOUNT_ID"]) + require.NotNil(t, sourceRoot.ColumnsConfig["ACCOUNT_ID"].Groupable) + assert.True(t, *sourceRoot.ColumnsConfig["ACCOUNT_ID"].Groupable) + + component, err := compileBootstrapComponent(ctx, shapeCompile.New(), shapeLoad.New(), repo, source, &DQLBootstrap{}, "/v1/api/shape") + require.NoError(t, err) + require.NotNil(t, component) + require.NotNil(t, component.View) + require.True(t, component.View.Groupable) + require.NotNil(t, component.View.Selector) + require.NotNil(t, component.View.Selector.Constraints) + assert.True(t, component.View.Selector.Constraints.OrderBy) + assert.Equal(t, "ACCOUNT_ID", component.View.Selector.Constraints.OrderByColumn["accountId"]) + assert.Equal(t, "ACCOUNT_ID", component.View.Selector.Constraints.OrderByColumn["accountid"]) + require.NotNil(t, component.View.ColumnsConfig["ACCOUNT_ID"]) + require.NotNil(t, component.View.ColumnsConfig["ACCOUNT_ID"].Groupable) + assert.True(t, *component.View.ColumnsConfig["ACCOUNT_ID"].Groupable) + assert.Equal(t, text.CaseFormatLowerCamel, component.Output.CaseFormat) + + inputVendorIDs := component.Input.Type.Parameters.Lookup("VendorIDs") + require.NotNil(t, inputVendorIDs) + assert.Equal(t, state.KindQuery, inputVendorIDs.In.Kind) + assert.Equal(t, "vendorIDs", inputVendorIDs.In.Name) + + inputFields := component.Input.Type.Parameters.Lookup("Fields") + require.NotNil(t, inputFields) + assert.Equal(t, state.KindQuery, inputFields.In.Kind) + assert.Equal(t, "_fields", inputFields.In.Name) + + outputView := component.Output.Type.Parameters.LookupByLocation(state.KindOutput, "view") + require.NotNil(t, outputView) + assert.Contains(t, outputView.Tag, `anonymous:"true"`) + assert.Equal(t, state.Many, component.Output.Cardinality) +} + +func TestCompileBootstrapComponent_MetaFormatOutputTypeMatchesRootView(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "dev", + Driver: "sqlite3", + DSN: ":memory:", + }, + }, + }) + + source := filepath.Join("..", "e2e", "v1", "dql", "dev", "vendorsrv", "meta_format.dql") + dqlBytes, err := os.ReadFile(source) + require.NoError(t, err) + planResult, err := shapeCompile.New().Compile(ctx, &shape.Source{ + Name: "meta_format", + Path: source, + DQL: string(dqlBytes), + }) + require.NoError(t, err) + artifact, err := shapeLoad.New().LoadComponent(ctx, planResult) + require.NoError(t, err) + loaded, ok := artifact.Component.(*shapeLoad.Component) + require.True(t, ok) + loadedRoot := lookupRootView(artifact.Resource, loaded.RootView) + require.NotNil(t, loadedRoot) + require.NotNil(t, loadedRoot.Schema) + t.Logf("loaded root view schema type: %v", loadedRoot.Schema.Type()) + + component, err := compileBootstrapComponent(ctx, shapeCompile.New(), shapeLoad.New(), repo, source, &DQLBootstrap{}, "/v1/api/shape") + require.NoError(t, err) + require.NotNil(t, component) + require.NotNil(t, component.View) + require.NotNil(t, component.View.Schema) + require.NotNil(t, component.View.Schema.Type()) + + outputView := component.Output.Type.Parameters.LookupByLocation(state.KindOutput, "view") + require.NotNil(t, outputView) + require.NotNil(t, outputView.Schema) + require.NotNil(t, outputView.Schema.Type()) + + rootType := component.View.OutputType() + outputType := outputView.OutputType() + assert.Equal(t, rootType.Kind(), outputType.Kind()) + if rootType.Kind() == reflect.Slice { + assert.Equal(t, rootType.Elem().Kind(), outputType.Elem().Kind()) + } + + outputSummary := component.Output.Type.Parameters.LookupByLocation(state.KindOutput, "summary") + require.NotNil(t, outputSummary) + require.NotNil(t, outputSummary.Schema) + require.NotNil(t, outputSummary.Schema.Type()) + require.NotNil(t, component.View.Template) + require.NotNil(t, component.View.Template.Summary) + require.NotNil(t, component.View.Template.Summary.Schema) + require.NotNil(t, component.View.Template.Summary.Schema.Type()) + assert.Equal(t, component.View.Template.Summary.Schema.Type().String(), outputSummary.Schema.Type().String()) +} + +func TestCompileBootstrapComponent_MetaFormatLiveOutputMarshal(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "dev", + Driver: "mysql", + DSN: "root:dev@tcp(127.0.0.1:3306)/dev?parseTime=true", + }, + }, + }) + + source := filepath.Join("..", "e2e", "v1", "dql", "dev", "vendorsrv", "meta_format.dql") + component, err := compileBootstrapComponent(ctx, shapeCompile.New(), shapeLoad.New(), repo, source, &DQLBootstrap{}, "/v1/api/shape") + require.NoError(t, err) + require.NotNil(t, component.View) + require.NotNil(t, component.View.Schema) + t.Logf("root view schema type: %v", component.View.Schema.Type()) + if outputView := component.Output.Type.Parameters.LookupByLocation(state.KindOutput, "view"); outputView != nil && outputView.Schema != nil { + t.Logf("output/view schema type: %v", outputView.Schema.Type()) + } + if outputSummary := component.Output.Type.Parameters.LookupByLocation(state.KindOutput, "summary"); outputSummary != nil && outputSummary.Schema != nil { + t.Logf("output/summary schema type: %v", outputSummary.Schema.Type()) + } + + svc := operator2.New() + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/v1/api/shape/dev/meta/vendors-format/", nil) + require.NoError(t, err) + sess := session.New(component.View, session.WithComponent(component), session.WithLocatorOptions(locator.WithRequest(req))) + outputValue, err := svc.Operate(ctx, sess, component) + require.NoError(t, err) + require.NotNil(t, outputValue) + t.Logf("output type: %T", outputValue) + + marshaller := marshaljson.New(&marshalconfig.IOConfig{CaseFormat: component.Output.CaseFormat}) + _, err = marshaller.Marshal(outputValue) + require.NoError(t, err) +} + +func TestCompileBootstrapComponent_UserAclMaterializesAnonymousOutputStateType(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "dev", + Driver: "sqlite3", + DSN: ":memory:", + }, + }, + }) + + source := filepath.Join("..", "e2e", "v1", "dql", "dev", "vendorsrv", "user_acl.dql") + component, err := compileBootstrapComponent(ctx, shapeCompile.New(), shapeLoad.New(), repo, source, &DQLBootstrap{}, "/v1/api/shape") + require.NoError(t, err) + require.NotNil(t, component) + require.True(t, component.Output.Type.Type().IsDefined()) + require.NotNil(t, component.Output.Type.Schema) + require.NotNil(t, component.Output.Type.Schema.Type()) + + outputView := component.Output.Type.Parameters.LookupByLocation(state.KindOutput, "view") + require.NotNil(t, outputView) + require.NotNil(t, outputView.Schema) + require.NotNil(t, outputView.Schema.Type()) + assert.Equal(t, reflect.Pointer, outputView.OutputType().Kind()) +} + +func TestCompileBootstrapComponent_PatchBasicOneBodyParameterIsSingular(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "dev", + Driver: "sqlite3", + DSN: ":memory:", + }, + }, + }) + + source := filepath.Join("..", "e2e", "v1", "dql", "dev", "events", "patch_basic_one.dql") + component, err := compileBootstrapComponent(ctx, shapeCompile.New(), shapeLoad.New(), repo, source, &DQLBootstrap{}, "/v1/api/shape") + require.NoError(t, err) + require.NotNil(t, component) + + bodyParams := component.Input.Type.Parameters.FilterByKind(state.KindRequestBody) + require.Len(t, bodyParams, 1) + body := bodyParams[0] + require.NotNil(t, body) + require.NotNil(t, body.Schema) + assert.Equal(t, state.One, body.Schema.Cardinality) + require.NotNil(t, body.Schema.Type()) + assert.NotEqual(t, reflect.Slice, body.Schema.Type().Kind(), "body schema should not remain slice-shaped") + + inputStateType := component.Input.Type.Type() + require.NotNil(t, inputStateType) + inputType := inputStateType.Type() + require.NotNil(t, inputType) + if inputType.Kind() == reflect.Ptr { + inputType = inputType.Elem() + } + field, ok := inputType.FieldByName("Foos") + require.True(t, ok) + assert.NotEqual(t, reflect.Slice, field.Type.Kind(), "input Foos field should not remain slice-shaped") +} diff --git a/gateway/go_bootstrap.go b/gateway/go_bootstrap.go new file mode 100644 index 000000000..5ae30fc18 --- /dev/null +++ b/gateway/go_bootstrap.go @@ -0,0 +1,212 @@ +package gateway + +import ( + "context" + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/gorouter" + shapeLoad "github.com/viant/datly/repository/shape/load" + shapePlan "github.com/viant/datly/repository/shape/plan" + shapeScan "github.com/viant/datly/repository/shape/scan" + "github.com/viant/datly/view/state" +) + +func (r *Service) applyGoBootstrap(ctx context.Context, repo *repository.Service, cfg *GoBootstrap) error { + if cfg == nil || len(cfg.Packages) == 0 { + return nil + } + baseDir, err := locateGoBootstrapBaseDir(r.Config) + if err != nil { + return err + } + routes, err := gorouter.Discover(ctx, baseDir, cfg.Packages, cfg.Exclude) + if err != nil { + return err + } + scanner := shapeScan.New() + planner := shapePlan.New() + loader := shapeLoad.New() + for _, route := range routes { + if route == nil || route.Source == nil { + continue + } + component, err := compileGoBootstrapComponent(ctx, scanner, planner, loader, repo, route) + if err != nil { + return err + } + exists, lookupErr := hasRepositoryProvider(ctx, repo, &component.Path) + if lookupErr != nil { + return lookupErr + } + if exists { + continue + } + repo.Register(component) + } + return nil +} + +func locateGoBootstrapBaseDir(cfg *Config) (string, error) { + if cfg == nil { + return "", fmt.Errorf("go bootstrap config was nil") + } + candidates := []string{cfg.DependencyURL, cfg.RouteURL, cfg.ContentURL} + for _, candidate := range candidates { + base := normalizeBootstrapPath(candidate) + if base == "" { + continue + } + if root := walkToGoMod(base); root != "" { + return root, nil + } + } + if wd, err := os.Getwd(); err == nil { + if root := walkToGoMod(wd); root != "" { + return root, nil + } + } + return "", fmt.Errorf("failed to locate Go bootstrap base dir") +} + +func normalizeBootstrapPath(candidate string) string { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return "" + } + candidate = strings.TrimPrefix(candidate, "file://localhost") + candidate = strings.TrimPrefix(candidate, "file://") + if candidate == "" { + return "" + } + return filepath.Clean(candidate) +} + +func walkToGoMod(base string) string { + base = filepath.Clean(base) + info, err := os.Stat(base) + if err != nil { + return "" + } + if !info.IsDir() { + base = filepath.Dir(base) + } + for { + if _, err := os.Stat(filepath.Join(base, "go.mod")); err == nil { + return base + } + parent := filepath.Dir(base) + if parent == base { + return "" + } + base = parent + } +} + +func compileGoBootstrapComponent(ctx context.Context, scanner *shapeScan.StructScanner, planner *shapePlan.Planner, loader *shapeLoad.Loader, repo *repository.Service, route *gorouter.RouteSource) (*repository.Component, error) { + scanResult, err := scanner.Scan(ctx, route.Source) + if err != nil { + return nil, fmt.Errorf("failed to scan Go bootstrap route %s: %w", route.Name, err) + } + planResult, err := planner.Plan(ctx, scanResult) + if err != nil { + return nil, fmt.Errorf("failed to plan Go bootstrap route %s: %w", route.Name, err) + } + componentArtifact, err := loader.LoadComponent(ctx, planResult, shape.WithLoadTypeContextPackages(true)) + if err != nil { + return nil, fmt.Errorf("failed to load Go bootstrap route %s: %w", route.Name, err) + } + mergeBootstrapSharedResources(componentArtifact.Resource, repo) + loaded, ok := componentArtifact.Component.(*shapeLoad.Component) + if !ok || loaded == nil { + return nil, fmt.Errorf("unexpected Go bootstrap component artifact for %s", route.Name) + } + return materializeBootstrapComponent(ctx, repo, componentArtifact, loaded, route.Name) +} + +func materializeBootstrapComponent(ctx context.Context, repo *repository.Service, componentArtifact *shape.ComponentArtifact, loaded *shapeLoad.Component, sourceName string) (*repository.Component, error) { + bootstrapMetadata := snapshotBootstrapViewMetadata(componentArtifact.Resource) + rootView := lookupRootView(componentArtifact.Resource, loaded.RootView) + if rootView == nil { + return nil, fmt.Errorf("missing root view %q for %s", loaded.RootView, sourceName) + } + method := strings.TrimSpace(strings.ToUpper(loaded.Method)) + uri := strings.TrimSpace(loaded.URI) + if method == "" && len(loaded.ComponentRoutes) > 0 && loaded.ComponentRoutes[0] != nil { + method = strings.TrimSpace(strings.ToUpper(loaded.ComponentRoutes[0].Method)) + } + if uri == "" && len(loaded.ComponentRoutes) > 0 && loaded.ComponentRoutes[0] != nil { + uri = strings.TrimSpace(loaded.ComponentRoutes[0].RoutePath) + } + if method == "" { + method = "GET" + } + if uri == "" { + return nil, fmt.Errorf("missing shape component route for %s", sourceName) + } + var outputType reflect.Type + if shouldMaterializeBootstrapOutputType(loaded, rootView) { + pkgPath := bootstrapTypePackage(loaded) + lookupType := componentArtifact.Resource.LookupType() + outputType, err = loaded.OutputReflectType(pkgPath, lookupType) + if err != nil { + return nil, fmt.Errorf("failed to materialize bootstrap output type for %s: %w", sourceName, err) + } + } + componentModel := &repository.Component{ + Path: contract.Path{ + Method: method, + URI: uri, + }, + Contract: contract.Contract{ + Input: contract.Input{ + Type: state.Type{ + Parameters: loaded.InputParameters(), + }, + }, + Output: contract.Output{ + CaseFormat: bootstrapOutputCaseFormat(loaded), + Cardinality: bootstrapOutputCardinality(loaded, rootView), + Type: state.Type{ + Parameters: loaded.OutputParameters(), + }, + }, + Service: defaultServiceForMethod(method, rootView), + }, + View: rootView, + TypeContext: loaded.TypeContext, + } + if outputType != nil { + if componentModel.Contract.Output.Type.Schema == nil { + componentModel.Contract.Output.Type.Schema = state.NewSchema(nil) + } + componentModel.Contract.Output.Type.SetType(outputType) + } + loadOptions := []repository.Option{} + if repo != nil { + loadOptions = append(loadOptions, repository.WithResources(repo.Resources())) + loadOptions = append(loadOptions, repository.WithExtensions(repo.Extensions())) + } + components, err := repository.LoadComponentsFromMap(ctx, map[string]any{ + "Resource": componentArtifact.Resource, + "Components": []*repository.Component{componentModel}, + }, loadOptions...) + if err != nil { + return nil, fmt.Errorf("failed to materialize bootstrap component for %s: %w", sourceName, err) + } + mergeBootstrapViewMetadata(components.Resource, bootstrapMetadata) + if err = components.Init(ctx); err != nil { + return nil, fmt.Errorf("failed to initialize bootstrap component for %s: %w", sourceName, err) + } + if len(components.Components) == 0 || components.Components[0] == nil { + return nil, fmt.Errorf("empty initialized bootstrap component for %s", sourceName) + } + mergeBootstrapView(components.Components[0].View, lookupRootView(bootstrapMetadata, loaded.RootView)) + return components.Components[0], nil +} diff --git a/gateway/mcp.go b/gateway/mcp.go index 4bf053020..0ec47596a 100644 --- a/gateway/mcp.go +++ b/gateway/mcp.go @@ -70,8 +70,7 @@ func (r *Router) mcpToolCallHandler(component *repository.Component, aRoute *Rou // 2) Apply parameters to request URL/query/body for _, p := range allParams { - name := strings.Title(p.Name) - value := params.Arguments[name] + value := toolArgumentValue(p, params.Arguments) pType := p.Schema.Type() if pType.Kind() == reflect.Ptr { pType = pType.Elem() @@ -277,6 +276,9 @@ func (r *Router) newToolHTTPRequest(method, URL string, body io.Reader) (*http.R if err != nil { return nil, jsonrpc.NewInvalidRequest(err.Error(), nil) } + if body != nil { + httpRequest.Header.Set("Content-Type", "application/json") + } return httpRequest, nil } @@ -404,6 +406,10 @@ func (r *Router) buildToolInputType(components *repository.Component) reflect.Ty } appendField(name, parameter.Schema.Type(), tag) case state.KindRequestBody: + if parameter.IsAnonymous() { + appendAnonymousBodyFields(&inputFields, uniqueFieldName, parameter.Schema.Type()) + continue + } // If body is a slice, mark optional in schema. var tag reflect.StructTag if parameter.Schema != nil && parameter.Schema.Type().Kind() == reflect.Slice { @@ -445,6 +451,77 @@ func (r *Router) buildToolInputType(components *repository.Component) reflect.Ty return reflect.StructOf(inputFields) } +func toolArgumentValue(parameter *state.Parameter, arguments map[string]interface{}) interface{} { + if parameter == nil { + return nil + } + if parameter.In != nil && parameter.In.Kind == state.KindRequestBody && parameter.IsAnonymous() && parameter.Schema != nil { + return anonymousBodyArgumentValue(arguments, parameter.Schema.Type()) + } + return arguments[strings.Title(parameter.Name)] +} + +func appendAnonymousBodyFields(fields *[]reflect.StructField, unique map[string]bool, bodyType reflect.Type) { + bodyType = indirectType(bodyType) + if bodyType == nil || bodyType.Kind() != reflect.Struct { + return + } + for i := 0; i < bodyType.NumField(); i++ { + field := bodyType.Field(i) + if !field.IsExported() { + continue + } + if unique[field.Name] { + continue + } + unique[field.Name] = true + *fields = append(*fields, field) + } +} + +func anonymousBodyArgumentValue(arguments map[string]interface{}, bodyType reflect.Type) interface{} { + bodyType = indirectType(bodyType) + if bodyType == nil || bodyType.Kind() != reflect.Struct { + return nil + } + payload := map[string]interface{}{} + for i := 0; i < bodyType.NumField(); i++ { + field := bodyType.Field(i) + if !field.IsExported() { + continue + } + value, ok := arguments[field.Name] + if !ok { + value, ok = arguments[jsonFieldName(field)] + } + if !ok { + continue + } + payload[jsonFieldName(field)] = value + } + if len(payload) == 0 { + return nil + } + return payload +} + +func jsonFieldName(field reflect.StructField) string { + if tag := field.Tag.Get("json"); tag != "" { + parts := strings.Split(tag, ",") + if parts[0] != "" && parts[0] != "-" { + return parts[0] + } + } + return strings.ToLower(field.Name[:1]) + field.Name[1:] +} + +func indirectType(rType reflect.Type) reflect.Type { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + return rType +} + func (r *Router) buildTemplateResourceIntegration(item *dpath.Item, aPath *dpath.Path, aRoute *Route, provider *repository.Provider) error { if aPath.Internal { return nil diff --git a/gateway/mcp_report_test.go b/gateway/mcp_report_test.go new file mode 100644 index 000000000..b2e98190b --- /dev/null +++ b/gateway/mcp_report_test.go @@ -0,0 +1,354 @@ +package gateway + +import ( + "context" + "embed" + "encoding/json" + "io" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + dpath "github.com/viant/datly/repository/path" + "github.com/viant/datly/repository/version" + "github.com/viant/datly/view" + "github.com/viant/datly/view/extension" + "github.com/viant/datly/view/state" + "github.com/viant/mcp-protocol/authorization" + "github.com/viant/mcp-protocol/schema" + serverproto "github.com/viant/mcp-protocol/server" + "github.com/viant/tagly/format/text" + "github.com/viant/xdatly/codec" + "github.com/viant/xreflect" +) + +type repositoryReportTestResource struct{} + +func (r *repositoryReportTestResource) LookupParameter(name string) (*state.Parameter, error) { + return nil, nil +} +func (r *repositoryReportTestResource) AppendParameter(parameter *state.Parameter) {} +func (r *repositoryReportTestResource) ViewSchema(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *repositoryReportTestResource) ViewSchemaPointer(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *repositoryReportTestResource) LookupType() xreflect.LookupType { return nil } +func (r *repositoryReportTestResource) LoadText(ctx context.Context, URL string) (string, error) { + return "", nil +} +func (r *repositoryReportTestResource) Codecs() *codec.Registry { return codec.New() } +func (r *repositoryReportTestResource) CodecOptions() *codec.Options { return codec.NewOptions(nil) } +func (r *repositoryReportTestResource) ExpandSubstitutes(value string) string { return value } +func (r *repositoryReportTestResource) ReverseSubstitutes(value string) string { return value } +func (r *repositoryReportTestResource) EmbedFS() *embed.FS { return nil } +func (r *repositoryReportTestResource) SetFSEmbedder(embedder *state.FSEmbedder) {} + +func TestRouter_buildToolInputType_FlattensAnonymousBody(t *testing.T) { + bodyType := reflect.StructOf([]reflect.StructField{ + { + Name: "Dimensions", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "AccountId", Type: reflect.TypeOf(false), Tag: `json:"accountId,omitempty" desc:"Account identifier"`}, + }), + Tag: `json:"dimensions,omitempty"`, + }, + { + Name: "Measures", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "TotalId", Type: reflect.TypeOf(false), Tag: `json:"totalId,omitempty" desc:"Total identifier"`}, + }), + Tag: `json:"measures,omitempty"`, + }, + { + Name: "Filters", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "VendorIDs", Type: reflect.TypeOf([]int{}), Tag: `json:"vendorIDs,omitempty" desc:"Vendor IDs"`}, + }), + Tag: `json:"filters,omitempty"`, + }, + {Name: "OrderBy", Type: reflect.TypeOf([]string{}), Tag: `json:"orderBy,omitempty"`}, + }) + bodyParam := state.NewParameter("Report", state.NewBodyLocation(""), state.WithParameterSchema(state.NewSchema(bodyType))) + bodyParam.Tag = `anonymous:"true"` + component := &repository.Component{ + Path: contract.Path{Method: "POST", URI: "/v1/api/dev/vendors-grouping/report"}, + View: &view.View{}, + Contract: contract.Contract{ + Input: contract.Input{ + Type: state.Type{Parameters: state.Parameters{bodyParam}}, + }, + }, + } + + rType := (&Router{}).buildToolInputType(component) + require.Equal(t, reflect.Struct, rType.Kind()) + _, ok := rType.FieldByName("Report") + assert.False(t, ok) + for _, name := range []string{"Dimensions", "Measures", "Filters", "OrderBy"} { + _, ok = rType.FieldByName(name) + assert.True(t, ok, name) + } +} + +func TestAnonymousBodyArgumentValue_UsesJSONFieldNames(t *testing.T) { + bodyType := reflect.StructOf([]reflect.StructField{ + {Name: "Dimensions", Type: reflect.StructOf([]reflect.StructField{{Name: "AccountId", Type: reflect.TypeOf(false), Tag: `json:"accountId,omitempty"`}}), Tag: `json:"dimensions,omitempty"`}, + {Name: "Measures", Type: reflect.StructOf([]reflect.StructField{{Name: "TotalId", Type: reflect.TypeOf(false), Tag: `json:"totalId,omitempty"`}}), Tag: `json:"measures,omitempty"`}, + {Name: "Filters", Type: reflect.StructOf([]reflect.StructField{{Name: "VendorIDs", Type: reflect.TypeOf([]int{}), Tag: `json:"vendorIDs,omitempty"`}}), Tag: `json:"filters,omitempty"`}, + {Name: "OrderBy", Type: reflect.TypeOf([]string{}), Tag: `json:"orderBy,omitempty"`}, + {Name: "Limit", Type: reflect.TypeOf((*int)(nil)), Tag: `json:"limit,omitempty"`}, + }) + + value := anonymousBodyArgumentValue(map[string]interface{}{ + "Dimensions": map[string]interface{}{"AccountId": true}, + "Measures": map[string]interface{}{"TotalId": true}, + "Filters": map[string]interface{}{"VendorIDs": []interface{}{1.0, 2.0, 3.0}}, + "OrderBy": []interface{}{"accountId"}, + }, bodyType) + + data, err := json.Marshal(value) + require.NoError(t, err) + assert.JSONEq(t, `{ + "dimensions":{"AccountId":true}, + "measures":{"TotalId":true}, + "filters":{"VendorIDs":[1,2,3]}, + "orderBy":["accountId"] + }`, string(data)) +} + +func TestAnonymousBodyArgumentValue_AcceptsJSONStyleTopLevelArgumentNames(t *testing.T) { + bodyType := reflect.StructOf([]reflect.StructField{ + {Name: "Dimensions", Type: reflect.StructOf([]reflect.StructField{{Name: "AdOrderId", Type: reflect.TypeOf(false), Tag: `json:"adOrderId,omitempty"`}}), Tag: `json:"dimensions,omitempty"`}, + {Name: "Measures", Type: reflect.StructOf([]reflect.StructField{{Name: "Bids", Type: reflect.TypeOf(false), Tag: `json:"bids,omitempty"`}}), Tag: `json:"measures,omitempty"`}, + }) + + value := anonymousBodyArgumentValue(map[string]interface{}{ + "dimensions": map[string]interface{}{"adOrderId": true}, + "measures": map[string]interface{}{"bids": true}, + }, bodyType) + + data, err := json.Marshal(value) + require.NoError(t, err) + assert.JSONEq(t, `{ + "dimensions":{"adOrderId":true}, + "measures":{"bids":true} + }`, string(data)) +} + +func TestRouter_addAuthTokenIfPresent_AddsBearerToken(t *testing.T) { + router := &Router{} + req, err := http.NewRequest(http.MethodPost, "http://localhost/v1/api/dev/vendors-grouping/report", nil) + require.NoError(t, err) + + ctx := context.WithValue(context.Background(), authorization.TokenKey, &authorization.Token{Token: "abc123"}) + router.addAuthTokenIfPresent(ctx, req) + + assert.Equal(t, "Bearer abc123", req.Header.Get("Authorization")) +} + +func TestRouter_mcpToolCallHandler_PassesAuthorizationToReportRoute(t *testing.T) { + bodyType := reflect.StructOf([]reflect.StructField{ + { + Name: "Dimensions", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "AccountId", Type: reflect.TypeOf(false), Tag: `json:"accountId,omitempty"`}, + }), + Tag: `json:"dimensions,omitempty"`, + }, + { + Name: "Measures", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "TotalId", Type: reflect.TypeOf(false), Tag: `json:"totalId,omitempty"`}, + }), + Tag: `json:"measures,omitempty"`, + }, + { + Name: "Filters", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "VendorIDs", Type: reflect.TypeOf([]int{}), Tag: `json:"vendorIDs,omitempty"`}, + }), + Tag: `json:"filters,omitempty"`, + }, + {Name: "OrderBy", Type: reflect.TypeOf([]string{}), Tag: `json:"orderBy,omitempty"`}, + }) + bodyParam := state.NewParameter("Report", state.NewBodyLocation(""), state.WithParameterSchema(state.NewSchema(bodyType))) + bodyParam.Tag = `anonymous:"true"` + component := &repository.Component{ + Path: contract.Path{Method: http.MethodPost, URI: "/v1/api/dev/vendors-grouping/report"}, + Contract: contract.Contract{ + Input: contract.Input{ + Type: state.Type{Parameters: state.Parameters{bodyParam}}, + }, + }, + } + + var actualAuth string + var actualBody string + route := &Route{ + Path: &contract.Path{Method: http.MethodPost, URI: "/v1/api/dev/vendors-grouping/report"}, + Handler: func(ctx context.Context, response http.ResponseWriter, req *http.Request) { + actualAuth = req.Header.Get("Authorization") + if req.Body != nil { + payload, _ := io.ReadAll(req.Body) + actualBody = string(payload) + } + response.WriteHeader(http.StatusOK) + _, _ = response.Write([]byte(`{"ok":true}`)) + }, + } + + handler := (&Router{}).mcpToolCallHandler(component, route) + ctx := context.WithValue(context.Background(), authorization.TokenKey, &authorization.Token{Token: "jwt-token"}) + result, rpcErr := handler(ctx, &schema.CallToolRequest{ + Params: schema.CallToolRequestParams{ + Arguments: map[string]interface{}{ + "Dimensions": map[string]interface{}{"AccountId": true}, + "Measures": map[string]interface{}{"TotalId": true}, + "Filters": map[string]interface{}{"VendorIDs": []interface{}{1.0, 2.0}}, + "OrderBy": []interface{}{"accountId"}, + }, + }, + }) + + require.Nil(t, rpcErr) + require.NotNil(t, result) + assert.Equal(t, "Bearer jwt-token", actualAuth) + assert.JSONEq(t, `{ + "dimensions":{"AccountId":true}, + "measures":{"TotalId":true}, + "filters":{"VendorIDs":[1,2]}, + "orderBy":["accountId"] + }`, actualBody) +} + +func TestRouter_newToolHTTPRequest_SetsJSONContentTypeForBody(t *testing.T) { + req, rpcErr := (&Router{}).newToolHTTPRequest(http.MethodPost, "http://localhost/v1/api/dev/vendors-grouping/report", strings.NewReader(`{"dimensions":{"accountId":true}}`)) + require.Nil(t, rpcErr) + require.NotNil(t, req) + assert.Equal(t, "application/json", req.Header.Get("Content-Type")) +} + +func TestRouter_buildToolsIntegration_RegistersReportTool(t *testing.T) { + bodyType := reflect.StructOf([]reflect.StructField{ + { + Name: "Dimensions", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "AccountId", Type: reflect.TypeOf(false), Tag: `json:"accountId,omitempty" desc:"Account identifier"`}, + }), + Tag: `json:"dimensions,omitempty" desc:"Selected grouping dimensions"`, + }, + { + Name: "Measures", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "TotalId", Type: reflect.TypeOf(false), Tag: `json:"totalId,omitempty" desc:"Total identifier"`}, + }), + Tag: `json:"measures,omitempty" desc:"Selected aggregate measures"`, + }, + { + Name: "Filters", + Type: reflect.StructOf([]reflect.StructField{ + {Name: "VendorIDs", Type: reflect.TypeOf([]int{}), Tag: `json:"vendorIDs,omitempty" desc:"Vendor IDs to include"`}, + }), + Tag: `json:"filters,omitempty" desc:"Report filters derived from original predicate parameters"`, + }, + {Name: "OrderBy", Type: reflect.TypeOf([]string{}), Tag: `json:"orderBy,omitempty"`}, + }) + bodyParam := state.NewParameter("Report", state.NewBodyLocation(""), state.WithParameterSchema(state.NewSchema(bodyType))) + bodyParam.Tag = `anonymous:"true"` + component := &repository.Component{ + Path: contract.Path{Method: http.MethodPost, URI: "/v1/api/dev/vendors-grouping/report"}, + View: &view.View{Name: "vendor"}, + Contract: contract.Contract{ + Input: contract.Input{ + Type: state.Type{Parameters: state.Parameters{bodyParam}}, + }, + }, + } + provider := repository.NewProvider( + contract.Path{Method: http.MethodPost, URI: "/v1/api/dev/vendors-grouping/report"}, + &version.Control{}, + func(ctx context.Context, opts ...repository.Option) (*repository.Component, error) { + return component, nil + }, + ) + route := &Route{ + Path: &contract.Path{Method: http.MethodPost, URI: "/v1/api/dev/vendors-grouping/report"}, + Handler: func(ctx context.Context, response http.ResponseWriter, req *http.Request) { + response.WriteHeader(http.StatusOK) + }, + } + registry := serverproto.NewRegistry() + router := &Router{mcpRegistry: registry} + + err := router.buildToolsIntegration(&dpath.Item{}, &dpath.Path{ + Path: contract.Path{Method: http.MethodPost, URI: "/v1/api/dev/vendors-grouping/report"}, + Meta: contract.Meta{Name: "vendors grouping report", Description: "Vendor grouping report"}, + ModelContextProtocol: contract.ModelContextProtocol{ + MCPTool: true, + }, + View: &dpath.ViewRef{Ref: "vendor"}, + }, route, provider) + require.NoError(t, err) + + tools := registry.ListRegisteredTools() + require.Len(t, tools, 1) + tool := tools[0] + assert.Equal(t, "vendorsgroupingreport", tool.Name) + require.Contains(t, tool.InputSchema.Properties, "dimensions") + require.Contains(t, tool.InputSchema.Properties, "measures") + require.Contains(t, tool.InputSchema.Properties, "filters") +} + +func TestRouter_buildToolInputType_UsesBuiltReportComponentParameters(t *testing.T) { + resource := view.EmptyResource() + rootView := view.NewView("vendor", "VENDOR") + rootView.Groupable = true + rootView.Columns = []*view.Column{ + view.NewColumn("AccountID", "int", reflect.TypeOf(0), false), + view.NewColumn("TotalSpend", "float64", reflect.TypeOf(float64(0)), false), + } + rootView.Columns[0].Groupable = true + rootView.Columns[1].Aggregate = true + for _, column := range rootView.Columns { + require.NoError(t, column.Init(&repositoryReportTestResource{}, text.CaseFormatUndefined, false)) + } + rootView.SetResource(resource) + resource.AddViews(rootView) + + inputType, err := state.NewType(state.WithParameters(state.Parameters{ + &state.Parameter{Name: "vendorIDs", In: state.NewQueryLocation("vendorIDs"), Schema: state.NewSchema(reflect.TypeOf([]int{})), Predicates: []*extension.PredicateConfig{{Name: "ByVendor"}}, Description: "Vendor IDs to include"}, + }), state.WithResource(&repositoryReportTestResource{})) + require.NoError(t, err) + inputType.Name = "VendorInput" + + component := &repository.Component{ + Path: contract.Path{Method: http.MethodGet, URI: "/v1/api/vendors"}, + Meta: contract.Meta{Name: "vendors"}, + View: rootView, + Report: &repository.Report{Enabled: true}, + Contract: contract.Contract{ + Input: contract.Input{Type: *inputType}, + }, + } + + reportComponent, err := repository.BuildReportComponent(nil, component) + require.NoError(t, err) + require.NotNil(t, reportComponent) + require.Len(t, reportComponent.Input.Type.Parameters, 1) + + rType := (&Router{}).buildToolInputType(reportComponent) + require.Equal(t, reflect.Struct, rType.Kind()) + _, ok := rType.FieldByName("Report") + assert.False(t, ok) + for _, name := range []string{"Dimensions", "Measures", "Filters", "OrderBy", "Limit", "Offset"} { + _, ok = rType.FieldByName(name) + assert.True(t, ok, name) + } +} diff --git a/gateway/option.go b/gateway/option.go index 2e77d159d..2b113a008 100644 --- a/gateway/option.go +++ b/gateway/option.go @@ -10,14 +10,15 @@ import ( ) type options struct { - config *Config - initializers []func(config *Config, fs *embed.FS) error - extensions *extension.Registry - metrics *gmetric.Service - repository *repository.Service - statusHandler http.Handler - embedFs *embed.FS - configURL string + config *Config + initializers []func(config *Config, fs *embed.FS) error + extensions *extension.Registry + metrics *gmetric.Service + repository *repository.Service + statusHandler http.Handler + embedFs *embed.FS + configURL string + refreshDisabled bool } func newOptions(ctx context.Context, opts ...Option) (*options, error) { @@ -103,3 +104,9 @@ func WithConfigURL(configURL string) Option { o.configURL = configURL } } + +func WithRefreshDisabled(enabled bool) Option { + return func(o *options) { + o.refreshDisabled = enabled + } +} diff --git a/gateway/patch_basic_one_e2e_test.go b/gateway/patch_basic_one_e2e_test.go new file mode 100644 index 000000000..87d84b933 --- /dev/null +++ b/gateway/patch_basic_one_e2e_test.go @@ -0,0 +1,142 @@ +package gateway + +import ( + "context" + "net/http" + "net/http/httptest" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/service/operator" + readerpkg "github.com/viant/datly/service/reader" + "github.com/viant/datly/service/session" + "github.com/viant/datly/view/state" + "github.com/viant/datly/view/state/kind/locator" +) + +func TestGateway_PatchBasicOne_NoRefresh(t *testing.T) { + root := filepath.Clean(filepath.Join("..", "e2e", "v1", "autogen", "Datly", "config_8081.json")) + svc, err := New(context.Background(), + WithConfigURL(root), + WithRefreshDisabled(true), + ) + require.NoError(t, err) + t.Cleanup(func() { + _ = svc.Close() + ResetSingleton() + }) + + req := httptest.NewRequest(http.MethodPatch, "/v1/api/shape/dev/basic/foos", strings.NewReader(`{"ID":4,"Quantity":2500,"Name":"changed - foo 4"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + svc.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code, rec.Body.String()) + require.Contains(t, rec.Body.String(), `"id":4`) + require.Contains(t, rec.Body.String(), `"quantity":2500`) + require.Contains(t, rec.Body.String(), `"name":"changed - foo 4"`) +} + +func TestGateway_PatchBasicOne_CurFoosValue(t *testing.T) { + root := filepath.Clean(filepath.Join("..", "e2e", "v1", "autogen", "Datly", "config_8081.json")) + svc, err := New(context.Background(), + WithConfigURL(root), + WithRefreshDisabled(true), + ) + require.NoError(t, err) + t.Cleanup(func() { + _ = svc.Close() + ResetSingleton() + }) + + component, err := svc.repository.Registry().Lookup(context.Background(), contract.NewPath(http.MethodPatch, "/v1/api/shape/dev/basic/foos")) + require.NoError(t, err) + resource := component.View.GetResource() + require.NotNil(t, resource) + curFoosView, err := resource.GetViews().Lookup("CurFoos") + require.NoError(t, err) + require.NotNil(t, curFoosView) + + plainDest := reflect.New(curFoosView.Schema.SliceType()).Interface() + err = readerpkg.New().ReadInto(context.Background(), plainDest, curFoosView) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPatch, "/v1/api/shape/dev/basic/foos", strings.NewReader(`{"ID":4,"Quantity":2500,"Name":"changed - foo 4"}`)) + req.Header.Set("Content-Type", "application/json") + + unmarshal := component.UnmarshalFunc(req) + locatorOptions := append(component.LocatorOptions(req, nil, unmarshal)) + locatorOptions = append(locatorOptions, locator.WithLogger(nil)) + aSession := session.New(component.View, + session.WithComponent(component), + session.WithLocatorOptions(locatorOptions...), + session.WithRegistry(svc.repository.Registry()), + session.WithOperate(operator.New().Operate)) + + err = aSession.InitKinds(state.KindComponent, state.KindHeader, state.KindRequestBody, state.KindForm, state.KindQuery) + require.NoError(t, err) + err = aSession.Populate(context.Background()) + require.NoError(t, err) + + param, err := component.View.ParamByName("CurFoos") + require.NoError(t, err) + value, has, err := aSession.LookupValue(context.Background(), param, aSession.Indirect(true)) + require.NoError(t, err) + require.True(t, has) + require.NotNil(t, value) +} + +func TestGateway_PatchBasicOne_CurFoosReadInto_WithAndWithoutResourceState(t *testing.T) { + root := filepath.Clean(filepath.Join("..", "e2e", "v1", "autogen", "Datly", "config_8081.json")) + svc, err := New(context.Background(), + WithConfigURL(root), + WithRefreshDisabled(true), + ) + require.NoError(t, err) + t.Cleanup(func() { + _ = svc.Close() + ResetSingleton() + }) + + component, err := svc.repository.Registry().Lookup(context.Background(), contract.NewPath(http.MethodPatch, "/v1/api/shape/dev/basic/foos")) + require.NoError(t, err) + resource := component.View.GetResource() + require.NotNil(t, resource) + curFoosView, err := resource.GetViews().Lookup("CurFoos") + require.NoError(t, err) + require.NotNil(t, curFoosView) + + req := httptest.NewRequest(http.MethodPatch, "/v1/api/shape/dev/basic/foos", strings.NewReader(`{"ID":4,"Quantity":2500,"Name":"changed - foo 4"}`)) + req.Header.Set("Content-Type", "application/json") + + unmarshal := component.UnmarshalFunc(req) + locatorOptions := append(component.LocatorOptions(req, nil, unmarshal)) + locatorOptions = append(locatorOptions, locator.WithLogger(nil)) + aSession := session.New(component.View, + session.WithComponent(component), + session.WithLocatorOptions(locatorOptions...), + session.WithRegistry(svc.repository.Registry()), + session.WithOperate(operator.New().Operate)) + + err = aSession.InitKinds(state.KindComponent, state.KindHeader, state.KindRequestBody, state.KindForm, state.KindQuery) + require.NoError(t, err) + + err = aSession.SetViewState(context.Background(), curFoosView) + require.NoError(t, err) + sqlQuery, buildErr := readerpkg.NewBuilder().Build(context.Background(), + readerpkg.WithBuilderView(curFoosView), + readerpkg.WithBuilderStatelet(aSession.State().Lookup(curFoosView)), + ) + require.NoError(t, buildErr) + require.NotNil(t, sqlQuery) + t.Logf("curFoos sql=%s args=%#v", sqlQuery.SQL, sqlQuery.Args) + + stateDest := reflect.New(curFoosView.Schema.SliceType()).Interface() + err = readerpkg.New().ReadInto(context.Background(), stateDest, curFoosView, readerpkg.WithResourceState(aSession.State())) + require.Error(t, err) +} diff --git a/gateway/route_struct.go b/gateway/route_struct.go index 7e2bbe291..57e7c9d94 100644 --- a/gateway/route_struct.go +++ b/gateway/route_struct.go @@ -7,6 +7,7 @@ import ( "github.com/viant/xreflect" "net/http" "reflect" + "strings" ) func (r *Router) NewStructRoute(URL string, provider *repository.Provider) *Route { @@ -45,5 +46,10 @@ func (r *Router) generateGoStruct(component *repository.Component) (int, []byte) fieldTag, _ = xreflect.RemoveTag(fieldTag, "sql") *tag = fieldTag })) + structContent = legacyStructFormatting(structContent) return http.StatusOK, []byte(structContent) } + +func legacyStructFormatting(content string) string { + return strings.ReplaceAll(content, ` internal:"true"`, ` internal:"true"`) +} diff --git a/gateway/route_struct_test.go b/gateway/route_struct_test.go new file mode 100644 index 000000000..1d78e7991 --- /dev/null +++ b/gateway/route_struct_test.go @@ -0,0 +1,35 @@ +package gateway + +import ( + "net/http" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestRouterGenerateGoStruct_PreservesLegacyInternalTagSpacing(t *testing.T) { + schemaType := reflect.StructOf([]reflect.StructField{ + { + Name: "VendorId", + Type: reflect.TypeOf((*int)(nil)), + Tag: reflect.StructTag(`sqlx:"VENDOR_ID" internal:"true"`), + }, + }) + schema := &state.Schema{Cardinality: state.Many} + schema.SetType(reflect.SliceOf(schemaType)) + + component := &repository.Component{ + View: &view.View{Schema: schema}, + } + + router := &Router{} + statusCode, content := router.generateGoStruct(component) + + require.Equal(t, http.StatusOK, statusCode) + require.True(t, strings.Contains(string(content), `sqlx:"VENDOR_ID" internal:"true"`), string(content)) +} diff --git a/gateway/service.go b/gateway/service.go index 794c2b2d7..cd49567d5 100644 --- a/gateway/service.go +++ b/gateway/service.go @@ -113,6 +113,7 @@ func New(ctx context.Context, opts ...Option) (*Service, error) { repository.WithFirebaseAuth(aConfig.Firebase), repository.WithDependencyURL(aConfig.DependencyURL), repository.WithRefreshFrequency(aConfig.SyncFrequency()), + repository.WithRefreshDisabled(options.refreshDisabled), repository.WithDispatcher(dispatcher.New), ) if err != nil { @@ -122,6 +123,9 @@ func New(ctx context.Context, opts ...Option) (*Service, error) { if err = (&Service{Config: aConfig}).applyDQLBootstrap(ctx, componentRepository, aConfig.DQLBootstrap); err != nil { return nil, fmt.Errorf("failed to apply DQL bootstrap: %w", err) } + if err = (&Service{Config: aConfig}).applyGoBootstrap(ctx, componentRepository, aConfig.GoBootstrap); err != nil { + return nil, fmt.Errorf("failed to apply Go bootstrap: %w", err) + } var mcpRegistry *serverproto.Registry if aConfig.MCP != nil { diff --git a/go.mod b/go.mod index be6c46a28..1d47c9ced 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module github.com/viant/datly go 1.25.0 +replace github.com/viant/xdatly => ../xdatly + require ( github.com/aerospike/aerospike-client-go v4.5.2+incompatible github.com/aws/aws-lambda-go v1.31.0 @@ -53,7 +55,7 @@ require ( github.com/viant/mcp-protocol v0.11.0 github.com/viant/structology v0.8.0 github.com/viant/tagly v0.3.0 - github.com/viant/x v0.4.0 + github.com/viant/x v0.4.1-0.20260306005005-975ded1e1bef github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 github.com/viant/xdatly/handler v0.0.0-20251208172928-dd34b7f09fd5 diff --git a/go.sum b/go.sum index 04ca6ae0b..fb3ce7263 100644 --- a/go.sum +++ b/go.sum @@ -1210,10 +1210,8 @@ github.com/viant/toolbox v0.37.0 h1:+zwSdbQh6I6ZEyxokQJr+1gQKbLEw6erc+Av5dwKtLU= github.com/viant/toolbox v0.37.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/velty v0.4.0 h1:eesQES/vCpcoPbM+gQLUBuLEL2sEO+A6s6lPpl8eKc4= github.com/viant/velty v0.4.0/go.mod h1:Q/UXviI2Nli8WROEpYd/BELMCSvnulQeyNrbPmMiS/Y= -github.com/viant/x v0.4.0 h1:n2xuxQdw4lYtMdi59IAQEZHPioNT9InENGGbapyz+P4= -github.com/viant/x v0.4.0/go.mod h1:1TvsnpZFqI9dYVzIkaSYJyJ/UkfxW7fnk0YFafWXrPg= -github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a h1:7CLO2LjVnFgOwN0FL3Q4y5NrD7DpclS21AiW6tDLIc8= -github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a/go.mod h1:lZKZHhVdCZ3U9TU6GUFxKoGN3dPtqt2HkDYzJPq5CEs= +github.com/viant/x v0.4.1-0.20260306005005-975ded1e1bef h1:KqWKMNloyzEg6nIn1pBK4CDEIcaRRhMrMUJr+k+xcPw= +github.com/viant/x v0.4.1-0.20260306005005-975ded1e1bef/go.mod h1:1TvsnpZFqI9dYVzIkaSYJyJ/UkfxW7fnk0YFafWXrPg= github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 h1:9Yry3PUBDzc4rWacOYvAq/TKrTV0agvMF0gwm2gaoHI= github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259/go.mod h1:fb8YgbVadk8X5ZLz49LWGzWmQlZd7Y/I5wE0ru44bIo= github.com/viant/xdatly/handler v0.0.0-20251208172928-dd34b7f09fd5 h1:CrT0HTlQul8FoGN0peylVczAOUEXKVqRAiB35ypRNHY= diff --git a/internal/inference/column.go b/internal/inference/column.go index 503b44227..ab56ef670 100644 --- a/internal/inference/column.go +++ b/internal/inference/column.go @@ -4,17 +4,26 @@ import ( "fmt" "github.com/viant/datly/view" "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/query" + "strconv" + "strings" ) type ColumnParameterNamer func(column *Field) string -func ExtractColumnConfig(column *sqlparser.Column) (*view.ColumnConfig, error) { - if column.Comments == "" { +func ExtractColumnConfig(column *sqlparser.Column, groupable bool) (*view.ColumnConfig, error) { + if column.Comments == "" && !groupable { return nil, nil } columnConfig := &view.ColumnConfig{} - if err := TryUnmarshalHint(column.Comments, columnConfig); err != nil { - return nil, fmt.Errorf("invalid column %v settings: %w, %s", column.Name, err, column.Comments) + if column.Comments != "" { + if err := TryUnmarshalHint(column.Comments, columnConfig); err != nil { + return nil, fmt.Errorf("invalid column %v settings: %w, %s", column.Name, err, column.Comments) + } + } + if groupable && columnConfig.Groupable == nil { + columnConfig.Groupable = &groupable } if columnConfig.DataType != nil { column.Type = *columnConfig.DataType @@ -23,3 +32,80 @@ func ExtractColumnConfig(column *sqlparser.Column) (*view.ColumnConfig, error) { columnConfig.Alias = column.Alias return columnConfig, nil } + +func GroupableColumns(aQuery *query.Select, columns sqlparser.Columns) map[string]bool { + result := make(map[string]bool) + if aQuery == nil || len(aQuery.GroupBy) == 0 || len(columns) == 0 { + return result + } + + index := map[string]*sqlparser.Column{} + for _, column := range columns { + if column == nil { + continue + } + for _, key := range columnGroupableKeys(column) { + index[key] = column + } + } + + for _, item := range aQuery.GroupBy { + for _, column := range groupByColumns(item, columns, index) { + result[column.Identity()] = true + } + } + return result +} + +func groupByColumns(item *query.Item, columns sqlparser.Columns, index map[string]*sqlparser.Column) []*sqlparser.Column { + if item == nil || item.Expr == nil { + return nil + } + + if literal, ok := item.Expr.(*expr.Literal); ok { + if position, err := strconv.Atoi(strings.TrimSpace(literal.Value)); err == nil && position > 0 && position <= len(columns) { + return []*sqlparser.Column{columns[position-1]} + } + } + + key := normalizedGroupableKey(sqlparser.Stringify(item.Expr)) + if key == "" { + return nil + } + if column, ok := index[key]; ok { + return []*sqlparser.Column{column} + } + return nil +} + +func columnGroupableKeys(column *sqlparser.Column) []string { + result := make([]string, 0, 4) + appendKey := func(value string) { + key := normalizedGroupableKey(value) + if key == "" { + return + } + for _, existing := range result { + if existing == key { + return + } + } + result = append(result, key) + } + + appendKey(column.Identity()) + appendKey(column.Name) + if column.Namespace != "" && column.Name != "" { + appendKey(column.Namespace + "." + column.Name) + } + appendKey(column.Expression) + return result +} + +func normalizedGroupableKey(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + return strings.ToLower(value) +} diff --git a/internal/inference/tag.go b/internal/inference/tag.go index a385e3d7b..2075a485e 100644 --- a/internal/inference/tag.go +++ b/internal/inference/tag.go @@ -100,6 +100,38 @@ func (t *Tags) buildSqlxTag(source *Spec, field *Field) { tagValue.Append("table=" + source.Table) } field.Tags.Set("sqlx", tagValue) + if _, ok := t.tags["source"]; !ok { + if sourceName := sourceColumnName(column); sourceName != "" { + field.Tags.Set("source", TagValue{sourceName}) + } + } +} + +func sourceColumnName(column *sqlparser.Column) string { + if column == nil { + return "" + } + if column.Alias != "" && column.Name != "" && !strings.EqualFold(column.Alias, column.Name) { + return column.Name + } + expression := strings.TrimSpace(column.Expression) + if expression == "" { + return "" + } + if index := strings.LastIndex(expression, "."); index != -1 { + expression = expression[index+1:] + } + expression = strings.Trim(expression, "` ") + if expression == "" { + return "" + } + if column.Alias != "" && strings.EqualFold(expression, column.Alias) { + return "" + } + if column.Name != "" && strings.EqualFold(expression, column.Name) { + return "" + } + return expression } func (t *Tags) buildJSONTag(field *Field) { diff --git a/internal/inference/tag_test.go b/internal/inference/tag_test.go new file mode 100644 index 000000000..234325e4a --- /dev/null +++ b/internal/inference/tag_test.go @@ -0,0 +1,44 @@ +package inference + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view/state" + "github.com/viant/sqlparser" + "github.com/viant/sqlx/metadata/sink" +) + +func TestSpecBuildType_PreservesSourceTagForAliasedProjection(t *testing.T) { + spec := &Spec{ + Table: "CI_TAXONOMY_DISQUALIFIED", + Columns: sqlparser.Columns{ + &sqlparser.Column{ + Name: "TAXONOMY_ID", + Alias: "TAXONOMY_ID", + Expression: "dq.SEGMENT_ID", + Namespace: "dq", + Type: "string", + }, + &sqlparser.Column{ + Name: "IS_DISQUALIFIED", + Type: "int", + }, + }, + pk: map[string]sink.Key{}, + Fk: map[string]sink.Key{}, + } + + err := spec.BuildType("taxonomy", "DisqualifiedView", state.Many, nil, nil) + require.NoError(t, err) + require.NotNil(t, spec.Type) + require.Len(t, spec.Type.columnFields, 2) + + field := spec.Type.columnFields[0] + require.Equal(t, `sqlx:"TAXONOMY_ID" source:"SEGMENT_ID" validate:"required"`, field.Tag) + + structField := field.StructField(WithStructTag()) + require.Equal(t, "SEGMENT_ID", reflect.StructTag(structField.Tag).Get("source")) + require.Equal(t, "TAXONOMY_ID", reflect.StructTag(structField.Tag).Get("sqlx")) +} diff --git a/internal/translator/function/groupable.go b/internal/translator/function/groupable.go new file mode 100644 index 000000000..78d07f178 --- /dev/null +++ b/internal/translator/function/groupable.go @@ -0,0 +1,44 @@ +package function + +import ( + "github.com/viant/datly/view" + "github.com/viant/sqlparser" +) + +type groupable struct{} +type groupingEnabled struct { + groupable +} + +func (c *groupable) Apply(args []string, column *sqlparser.Column, resource *view.Resource, aView *view.View) error { + values, err := convertArguments(c, args) + if err != nil { + return err + } + aView.Groupable = values[0].(bool) + return nil +} + +func (c *groupable) Name() string { + return "groupable" +} + +func (c *groupingEnabled) Name() string { + return "grouping_enabled" +} + +func (c *groupable) Description() string { + return "sets view.Groupable flag to enable dynamic group by rewriting for the view" +} + +func (c *groupable) Arguments() []*Argument { + return []*Argument{ + { + Name: "flag", + Description: "enable dynamic group by for the view", + Required: false, + Default: true, + DataType: "bool", + }, + } +} diff --git a/internal/translator/function/groupable_test.go b/internal/translator/function/groupable_test.go new file mode 100644 index 000000000..5730c4781 --- /dev/null +++ b/internal/translator/function/groupable_test.go @@ -0,0 +1,48 @@ +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" +) + +func TestGroupable_Apply(t *testing.T) { + useCases := []struct { + description string + args []string + expected bool + }{ + { + description: "defaults to true when flag omitted", + args: nil, + expected: true, + }, + { + description: "supports explicit false", + args: []string{"false"}, + expected: false, + }, + } + + for _, useCase := range useCases { + t.Run(useCase.description, func(t *testing.T) { + aView := &view.View{} + fn := &groupable{} + + err := fn.Apply(useCase.args, nil, nil, aView) + require.NoError(t, err) + require.Equal(t, useCase.expected, aView.Groupable) + }) + } +} + +func TestGroupingEnabledAlias_Apply(t *testing.T) { + aView := &view.View{} + fn := &groupingEnabled{} + + err := fn.Apply(nil, nil, nil, aView) + require.NoError(t, err) + require.True(t, aView.Groupable) + require.Equal(t, "grouping_enabled", fn.Name()) +} diff --git a/internal/translator/function/init.go b/internal/translator/function/init.go index b0141c3d8..12442bb85 100644 --- a/internal/translator/function/init.go +++ b/internal/translator/function/init.go @@ -8,6 +8,8 @@ func init() { _registry.Register(&allowedOrderByColumns{}) _registry.Register(&cardinality{}) _registry.Register(&allownulls{}) + _registry.Register(&groupable{}) + _registry.Register(&groupingEnabled{}) _registry.Register(&matchStrategy{}) _registry.Register(&batchSize{}) _registry.Register(&partitioner{}) diff --git a/internal/translator/report_runtime_test.go b/internal/translator/report_runtime_test.go new file mode 100644 index 000000000..878826ab9 --- /dev/null +++ b/internal/translator/report_runtime_test.go @@ -0,0 +1,78 @@ +package translator + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/afs" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/gateway" + "github.com/viant/datly/gateway/runtime/standalone" + "github.com/viant/datly/repository" + "github.com/viant/datly/view" +) + +func TestService_persistRouterRule_PreservesReportMetadataOnRouteComponent(t *testing.T) { + routeRoot := t.TempDir() + + repoOptions := &options.Repository{ + RepositoryURL: routeRoot, + APIPrefix: "/v1/api", + } + cfg := &Config{ + repository: repoOptions, + Config: &standalone.Config{ + Config: &gateway.Config{ + ExposableConfig: gateway.ExposableConfig{ + RouteURL: routeRoot, + }, + }, + }, + } + svc := &Service{ + Repository: &Repository{ + fs: afs.New(), + Config: cfg, + }, + fs: afs.New(), + } + + ruleOptions := &options.Rule{ + Project: routeRoot, + ModulePrefix: "dev", + Source: []string{routeRoot + "/vendors_grouping.sql"}, + } + require.NoError(t, os.WriteFile(ruleOptions.Source[0], []byte("SELECT 1"), 0o600)) + require.NoError(t, ruleOptions.Init()) + + resource := NewResource(ruleOptions, repoOptions, nil) + resource.Rule.Root = "vendor" + resource.Rule.Route.URI = "/vendors-grouping" + resource.Rule.Route.Method = "GET" + resource.Rule.Report = &repository.Report{Enabled: true} + resource.Rule.Viewlets.Append(&Viewlet{ + Name: "vendor", + View: &View{ + View: view.View{ + Name: "vendor", + }, + }, + }) + + require.NoError(t, svc.persistRouterRule(context.Background(), resource, "Reader")) + require.NotEmpty(t, svc.Repository.Files) + + var persisted string + for _, candidate := range svc.Repository.Files { + if strings.HasSuffix(candidate.URL, "vendors_grouping.yaml") { + persisted = candidate.Content + break + } + } + require.NotEmpty(t, persisted) + require.Contains(t, persisted, "Report:") + require.Contains(t, persisted, "Enabled: true") +} diff --git a/internal/translator/resource.go b/internal/translator/resource.go index bf63f33ef..974a4e70e 100644 --- a/internal/translator/resource.go +++ b/internal/translator/resource.go @@ -16,6 +16,7 @@ import ( "github.com/viant/datly/internal/msg" "github.com/viant/datly/internal/setter" tparser "github.com/viant/datly/internal/translator/parser" + "github.com/viant/datly/repository" "github.com/viant/datly/repository/content" expand "github.com/viant/datly/service/executor/expand" "github.com/viant/datly/shared" @@ -39,6 +40,7 @@ var ( handlerSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$handler\s*\(([^)]*)\)\s*\)\s*$`) inputSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$input\s*\(([^)]*)\)\s*\)\s*$`) outputSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$output\s*\(([^)]*)\)\s*\)\s*$`) + reportSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$report\s*\(([^)]*)\)\s*\)\s*$`) marshalSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$marshal\s*\(\s*['"]([^'"]+)['"]\s*,\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) unmarshalSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$unmarshal\s*\(\s*['"]([^'"]+)['"]\s*,\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) formatSettingsLineExpr = regexp.MustCompile(`(?im)^\s*#(?:settings|define|set)\s*\(\s*\$_\s*=\s*\$format\s*\(\s*['"]([^'"]+)['"]\s*\)\s*\)\s*$`) @@ -61,6 +63,7 @@ type routeSettingsDirective struct { Format string DateFormat string CaseFormat string + Report *repository.Report } type ( @@ -486,6 +489,9 @@ func (r *Resource) extractRuleSetting(dSQL *string) error { if directive.CaseFormat != "" { r.Rule.Route.Output.CaseFormat = text.CaseFormat(directive.CaseFormat) } + if directive.Report != nil { + r.Rule.Report = directive.Report + } *dSQL = removeSettingsDirectives(*dSQL) } r.Rule.applyShortHands() @@ -574,6 +580,12 @@ func parseSettingsDirectives(dSQL string) (*routeSettingsDirective, bool, error) } ret.OutputType = value } + matches = reportSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) + if len(matches) > 0 { + found = true + last := matches[len(matches)-1] + ret.Report = parseReportSettings(last[1]) + } matches = marshalSettingsLineExpr.FindAllStringSubmatch(dSQL, -1) if len(matches) > 0 { @@ -722,6 +734,7 @@ func removeSettingsDirectives(dSQL string) string { dSQL = handlerSettingsLineExpr.ReplaceAllString(dSQL, "") dSQL = inputSettingsLineExpr.ReplaceAllString(dSQL, "") dSQL = outputSettingsLineExpr.ReplaceAllString(dSQL, "") + dSQL = reportSettingsLineExpr.ReplaceAllString(dSQL, "") dSQL = marshalSettingsLineExpr.ReplaceAllString(dSQL, "") dSQL = unmarshalSettingsLineExpr.ReplaceAllString(dSQL, "") dSQL = formatSettingsLineExpr.ReplaceAllString(dSQL, "") @@ -730,6 +743,41 @@ func removeSettingsDirectives(dSQL string) string { return dSQL } +func parseReportSettings(input string) *repository.Report { + args := parseQuotedArgs(input) + ret := &repository.Report{ + Enabled: true, + Dimensions: "Dimensions", + Measures: "Measures", + Filters: "Filters", + OrderBy: "OrderBy", + Limit: "Limit", + Offset: "Offset", + } + if len(args) > 0 { + ret.Input = args[0] + } + if len(args) > 1 { + ret.Dimensions = args[1] + } + if len(args) > 2 { + ret.Measures = args[2] + } + if len(args) > 3 { + ret.Filters = args[3] + } + if len(args) > 4 { + ret.OrderBy = args[4] + } + if len(args) > 5 { + ret.Limit = args[5] + } + if len(args) > 6 { + ret.Offset = args[6] + } + return ret +} + func removeHashImportDirectives(dSQL string) string { return hashImportLineExpr.ReplaceAllString(dSQL, "") } diff --git a/internal/translator/rule.go b/internal/translator/rule.go index b2cd35c96..6cea3bbdb 100644 --- a/internal/translator/rule.go +++ b/internal/translator/rule.go @@ -9,6 +9,7 @@ import ( "github.com/viant/datly/internal/inference" "github.com/viant/datly/internal/setter" "github.com/viant/datly/internal/translator/parser" + "github.com/viant/datly/repository" "github.com/viant/datly/repository/async" "github.com/viant/datly/repository/content" "github.com/viant/datly/repository/contract" @@ -65,9 +66,10 @@ type ( Include []string `json:",omitempty"` indexNamespaces IsGeneratation bool - XMLUnmarshalType string `json:",omitempty"` - JSONUnmarshalType string `json:",omitempty"` - JSONMarshalType string `json:",omitempty"` + XMLUnmarshalType string `json:",omitempty"` + JSONUnmarshalType string `json:",omitempty"` + JSONMarshalType string `json:",omitempty"` + Report *repository.Report `json:",omitempty" yaml:"Report,omitempty"` OutputParameter *inference.Parameter } @@ -123,18 +125,19 @@ func (r *Rule) DSQLSetting() interface{} { return struct { URI string Method string - Type string `json:",omitempty"` - InputType string `json:",omitempty"` - OutputType string `json:",omitempty"` - MessageBus string `json:",omitempty"` - CompressAboveSize int `json:",omitempty"` - HandlerArgs []string `json:",omitempty"` - DocURL string `json:",omitempty"` - DocURLs []string `json:",omitempty"` - Internal bool `json:",omitempty"` - JSONUnmarshalType string `json:",omitempty"` - JSONMarshalType string `json:",omitempty"` - Connector string `json:",omitempty"` + Type string `json:",omitempty"` + InputType string `json:",omitempty"` + OutputType string `json:",omitempty"` + MessageBus string `json:",omitempty"` + CompressAboveSize int `json:",omitempty"` + HandlerArgs []string `json:",omitempty"` + DocURL string `json:",omitempty"` + DocURLs []string `json:",omitempty"` + Internal bool `json:",omitempty"` + JSONUnmarshalType string `json:",omitempty"` + JSONMarshalType string `json:",omitempty"` + Connector string `json:",omitempty"` + Report *repository.Report `json:",omitempty"` contract.ModelContextProtocol contract.Meta }{ @@ -152,6 +155,7 @@ func (r *Rule) DSQLSetting() interface{} { JSONUnmarshalType: r.JSONUnmarshalType, JSONMarshalType: r.JSONMarshalType, Connector: r.Connector, + Report: r.Report, ModelContextProtocol: r.ModelContextProtocol, Meta: r.Meta, } diff --git a/internal/translator/service.go b/internal/translator/service.go index f383b9a37..7367f6e3b 100644 --- a/internal/translator/service.go +++ b/internal/translator/service.go @@ -332,6 +332,9 @@ func (s *Service) persistRouterRule(ctx context.Context, resource *Resource, ser } route.Component.Meta = resource.Rule.Meta + if resource.Rule.Report != nil { + route.Component.Report = resource.Rule.Report.Clone() + } if route.Component.Meta.DescriptionURI != "" { URL := url.Join(baseRuleURL, route.Component.Meta.DescriptionURI) description, err := s.fs.DownloadWithURL(ctx, URL) diff --git a/internal/translator/view.go b/internal/translator/view.go index fc0096ca8..09732a307 100644 --- a/internal/translator/view.go +++ b/internal/translator/view.go @@ -182,9 +182,13 @@ func (v *View) buildSelector(namespace *Viewlet, rule *Rule) { Offset: true, Projection: true, } - if !v.ParameterDerived { - selector.Constraints.Filterable = []string{"*"} - } + } + setter.SetBoolIfFalse(&selector.Constraints.Criteria, true) + setter.SetBoolIfFalse(&selector.Constraints.Limit, true) + setter.SetBoolIfFalse(&selector.Constraints.Offset, true) + setter.SetBoolIfFalse(&selector.Constraints.Projection, true) + if len(selector.Constraints.Filterable) == 0 && !v.ParameterDerived { + selector.Constraints.Filterable = []string{"*"} } if querySelectors, ok := namespace.Resource.Declarations.QuerySelectors[namespace.Name]; ok { diff --git a/internal/translator/view_selector_test.go b/internal/translator/view_selector_test.go new file mode 100644 index 000000000..4c8065a53 --- /dev/null +++ b/internal/translator/view_selector_test.go @@ -0,0 +1,58 @@ +package translator + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/internal/inference" + tparser "github.com/viant/datly/internal/translator/parser" + "github.com/viant/datly/view" +) + +func TestView_buildSelector_MergesDefaultConstraints(t *testing.T) { + namespace := &Viewlet{ + Name: "vendor", + Resource: &Resource{ + Declarations: &tparser.Declarations{ + QuerySelectors: map[string]inference.State{}, + }, + }, + } + + aView := &View{ + View: view.View{ + Name: "vendor", + Selector: &view.Config{ + Constraints: &view.Constraints{ + OrderBy: true, + OrderByColumn: map[string]string{ + "accountId": "ACCOUNT_ID", + }, + }, + }, + }, + } + + rule := &Rule{ + Root: "vendor", + Viewlets: Viewlets{ + registry: map[string]*Viewlet{ + "vendor": {Name: "vendor", View: aView}, + }, + keys: []string{"vendor"}, + }, + } + + aView.buildSelector(namespace, rule) + + require.NotNil(t, aView.Selector) + require.NotNil(t, aView.Selector.Constraints) + require.Equal(t, 25, aView.Selector.Limit) + require.True(t, aView.Selector.Constraints.Criteria) + require.True(t, aView.Selector.Constraints.Limit) + require.True(t, aView.Selector.Constraints.Offset) + require.True(t, aView.Selector.Constraints.Projection) + require.True(t, aView.Selector.Constraints.OrderBy) + require.Equal(t, "ACCOUNT_ID", aView.Selector.Constraints.OrderByColumn["accountId"]) + require.Equal(t, []string{"*"}, aView.Selector.Constraints.Filterable) +} diff --git a/internal/translator/viewlet.go b/internal/translator/viewlet.go index bd31ecb50..6f914deaf 100644 --- a/internal/translator/viewlet.go +++ b/internal/translator/viewlet.go @@ -217,6 +217,12 @@ func NewViewlet(name, SQL string, join *query.Join, resource *Resource) *Viewlet func (v *Viewlet) discoverTables(ctx context.Context, db *sql.DB, SQL string) (err error) { v.Table, err = inference.NewTable(ctx, db, SQL) + groupableColumns := map[string]bool{} + if v.Table != nil && v.View != nil && v.View.Groupable { + if parsed, parseErr := sqlparser.ParseQuery(inference.TrimParenthesis(SQL)); parseErr == nil { + groupableColumns = inference.GroupableColumns(parsed, v.Table.QueryColumns) + } + } if v.Table != nil { for _, column := range v.Table.QueryColumns { name := column.Alias @@ -224,7 +230,7 @@ func (v *Viewlet) discoverTables(ctx context.Context, db *sql.DB, SQL string) (e name = column.Name } v.Whitelisted = append(v.Whitelisted, strings.ToLower(name)) - columnConfig, err := inference.ExtractColumnConfig(column) + columnConfig, err := inference.ExtractColumnConfig(column, groupableColumns[column.Identity()]) if err != nil { return err } diff --git a/internal/translator/viewlet_groupable_test.go b/internal/translator/viewlet_groupable_test.go new file mode 100644 index 000000000..40287cc30 --- /dev/null +++ b/internal/translator/viewlet_groupable_test.go @@ -0,0 +1,71 @@ +package translator + +import ( + "context" + "database/sql" + "path/filepath" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" +) + +func TestViewlet_discoverTables_GroupableColumnConfig(t *testing.T) { + ctx := context.Background() + dsn := filepath.Join(t.TempDir(), "viewlet_groupable.sqlite") + db, err := sql.Open("sqlite3", dsn) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(ctx, `CREATE TABLE sales (region_id TEXT, total_sales REAL, country_id TEXT)`) + require.NoError(t, err) + + useCases := []struct { + description string + sql string + groupable bool + expect map[string]bool + }{ + { + description: "flags groupable columns from ordinal group by", + sql: `SELECT region_id, SUM(total_sales) AS total_sales, country_id FROM sales GROUP BY 1, 3`, + groupable: true, + expect: map[string]bool{ + "region_id": true, + "country_id": true, + }, + }, + { + description: "flags groupable columns from alias and name group by", + sql: `SELECT region_id AS region, SUM(total_sales) AS total_sales, country_id FROM sales GROUP BY region, country_id`, + groupable: true, + expect: map[string]bool{ + "region": true, + "country_id": true, + }, + }, + { + description: "does not infer groupable columns without explicit view grouping", + sql: `SELECT region_id, SUM(total_sales) AS total_sales, country_id FROM sales GROUP BY 1, 3`, + groupable: false, + expect: map[string]bool{}, + }, + } + + for _, useCase := range useCases { + t.Run(useCase.description, func(t *testing.T) { + viewlet := NewViewlet("sales", useCase.sql, nil, &Resource{}) + viewlet.View = &View{} + viewlet.View.Groupable = useCase.groupable + err := viewlet.discoverTables(ctx, db, useCase.sql) + require.NoError(t, err) + + actual := map[string]bool{} + for _, config := range viewlet.ColumnConfig { + require.NotNil(t, config.Groupable) + actual[config.Name] = *config.Groupable + } + require.Equal(t, useCase.expect, actual) + }) + } +} diff --git a/repository/component.go b/repository/component.go index 179ff7a9a..1c2c17bb1 100644 --- a/repository/component.go +++ b/repository/component.go @@ -3,6 +3,7 @@ package repository import ( "context" "embed" + stdjson "encoding/json" "fmt" "net/http" "reflect" @@ -48,6 +49,7 @@ type ( View *view.View `json:",omitempty"` NamespacedView *view.NamespacedView Handler *handler.Handler `json:",omitempty"` + Report *Report `json:",omitempty" yaml:"Report,omitempty"` TypeContext *typectx.Context `json:",omitempty" yaml:",omitempty"` indexedView view.NamedViews SourceURL string @@ -440,6 +442,13 @@ func (c *Component) UnmarshalFor(opts ...UnmarshalOption) shared.Unmarshal { } req := options.request // capture for closure + if c != nil && c.Report != nil && c.Report.Enabled && c.Handler != nil { + if parameter := c.Input.Type.AnonymousParameters(); parameter != nil && parameter.In != nil && parameter.In.Kind == state.KindRequestBody { + return func(data []byte, dest interface{}) error { + return stdjson.Unmarshal(data, dest) + } + } + } return func(data []byte, dest interface{}) error { if len(interceptors) > 0 || req != nil { return c.Content.Marshaller.JSON.JsonMarshaller.Unmarshal(data, dest, interceptors, req) @@ -561,6 +570,13 @@ func WithView(aView *view.View) ComponentOption { } } +func WithReport(report *Report) ComponentOption { + return func(c *Component) error { + c.Report = report.Clone() + return nil + } +} + func WithHandler(aHandler xhandler.Handler) ComponentOption { return func(c *Component) error { c.Handler = handler.NewHandler(aHandler) diff --git a/repository/contract/contract.go b/repository/contract/contract.go index 7258d3617..a61204dac 100644 --- a/repository/contract/contract.go +++ b/repository/contract/contract.go @@ -29,10 +29,10 @@ type ( // Types returns all types func (c *Contract) Types() []*state.Type { var types []*state.Type - if c.Input.Type.Type().IsDefined() { + if inputType := c.Input.Type.Type(); inputType != nil && inputType.IsDefined() { types = append(types, &c.Input.Type) } - if c.Output.Type.Type().IsDefined() { + if outputType := c.Output.Type.Type(); outputType != nil && outputType.IsDefined() { types = append(types, &c.Output.Type) } return types diff --git a/repository/handler/handler.go b/repository/handler/handler.go index f8434d71a..a2a91c7ea 100644 --- a/repository/handler/handler.go +++ b/repository/handler/handler.go @@ -128,7 +128,7 @@ func (h *Handler) buildFactoryOptions() ([]handler.Option, error) { func NewHandler(handler handler.Handler) *Handler { rType := reflect.TypeOf(handler) - return &Handler{Type: rType.Name(), _type: rType} + return &Handler{Type: rType.Name(), _type: rType, handler: handler} } func lookupByPackagePathAlias(lookup xreflect.LookupType, typeName string) reflect.Type { diff --git a/repository/option.go b/repository/option.go index 9c2b9b34b..8ac783c14 100644 --- a/repository/option.go +++ b/repository/option.go @@ -45,6 +45,7 @@ type Options struct { authConfig aconfig.Config shapePipeline bool legacyTypeContext bool + refreshDisabled bool } func (o *Options) UseColumn() bool { @@ -195,6 +196,14 @@ func WithRefreshFrequency(refreshFrequency time.Duration) Option { } } +// WithRefreshDisabled suppresses repository change polling and lazy hot-reload. +// Disabled by default to preserve existing behavior. +func WithRefreshDisabled(enabled bool) Option { + return func(o *Options) { + o.refreshDisabled = enabled + } +} + func WithResourceURL(URL string) Option { return func(o *Options) { o.resourceURL = URL diff --git a/repository/option_shape_test.go b/repository/option_shape_test.go index 11bf4ecba..ce232bcf2 100644 --- a/repository/option_shape_test.go +++ b/repository/option_shape_test.go @@ -27,3 +27,14 @@ func TestWithLegacyTypeContext(t *testing.T) { WithLegacyTypeContext(false)(opts) assert.False(t, opts.legacyTypeContext) } + +func TestWithRefreshDisabled(t *testing.T) { + opts := NewOptions(nil) + assert.False(t, opts.refreshDisabled) + + WithRefreshDisabled(true)(opts) + assert.True(t, opts.refreshDisabled) + + WithRefreshDisabled(false)(opts) + assert.False(t, opts.refreshDisabled) +} diff --git a/repository/path/container.go b/repository/path/container.go index 8eac0d373..e8e316bbc 100644 --- a/repository/path/container.go +++ b/repository/path/container.go @@ -37,6 +37,18 @@ type ( With []string } + Report struct { + Enabled bool `json:",omitempty" yaml:"Enabled,omitempty"` + MCPTool *bool `json:",omitempty" yaml:"MCPTool,omitempty"` + Input string `json:",omitempty" yaml:"Input,omitempty"` + Dimensions string `json:",omitempty" yaml:"Dimensions,omitempty"` + Measures string `json:",omitempty" yaml:"Measures,omitempty"` + Filters string `json:",omitempty" yaml:"Filters,omitempty"` + OrderBy string `json:",omitempty" yaml:"OrderBy,omitempty"` + Limit string `json:",omitempty" yaml:"Limit,omitempty"` + Offset string `json:",omitempty" yaml:"Offset,omitempty"` + } + ViewRef struct { Ref string `yaml:"Ref" json:"Ref"` // Ref is the reference to the view definition } @@ -47,6 +59,7 @@ type ( contract.Meta `yaml:",inline"` contract.ModelContextProtocol `yaml:",inline"` Handler *Handler `yaml:"Handler" json:"Handler"` + Report *Report `yaml:"Report,omitempty" json:"Report,omitempty"` Internal bool `json:"Internal,omitempty" yaml:"Internal,omitempty" ` Connector string `json:",omitempty"` ContentURL string `json:"ContentURL,omitempty" yaml:"ContentURL,omitempty" ` diff --git a/repository/path/service.go b/repository/path/service.go index 2ce530d9d..44118dbbd 100644 --- a/repository/path/service.go +++ b/repository/path/service.go @@ -240,8 +240,11 @@ func (s *Service) load(ctx context.Context) error { } func (s *Service) onModify(ctx context.Context, object storage.Object) error { - path := url.Path(object.URL()) - prev := s.lookupRouteBySourceURL(path) + sourceURL := object.URL() + prev := s.lookupRouteBySourceURL(sourceURL) + if prev == nil { + prev = s.lookupRouteBySourceURL(url.Path(sourceURL)) + } if prev != nil && prev.Version.HasChanged(object.ModTime()) { return nil } @@ -263,8 +266,11 @@ func (s *Service) onModify(ctx context.Context, object storage.Object) error { } func (s *Service) onDelete(ctx context.Context, object storage.Object) error { - path := url.Path(object.URL()) - prev := s.lookupRouteBySourceURL(path) + sourceURL := object.URL() + prev := s.lookupRouteBySourceURL(sourceURL) + if prev == nil { + prev = s.lookupRouteBySourceURL(url.Path(sourceURL)) + } if prev == nil { return nil } @@ -272,7 +278,7 @@ func (s *Service) onDelete(ctx context.Context, object storage.Object) error { prev.Version.Increase() // TODO delete works fine but after adding back rule file we get panic - //s.delete(prev, path) + //s.delete(prev, sourceURL) return nil } diff --git a/repository/path/service_test.go b/repository/path/service_test.go index ed1888b3a..d723ada37 100644 --- a/repository/path/service_test.go +++ b/repository/path/service_test.go @@ -4,6 +4,7 @@ import ( "context" _ "embed" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/viant/afs" "github.com/viant/afs/asset" "github.com/viant/afs/file" @@ -64,3 +65,29 @@ func TestNew(t *testing.T) { } } + +func TestService_onModify_DoesNotAppendDuplicateForTrackedURL(t *testing.T) { + location := "mem://localhost/test/routes_modify" + mgr, err := afs.Manager(location) + require.NoError(t, err) + err = asset.Create(mgr, location, []*asset.Resource{ + asset.New("dev/vendor.yml", file.DefaultFileOsMode, false, "", ruleVendor), + }) + require.NoError(t, err) + + service, err := New(context.Background(), afs.New(), location, time.Second) + require.NoError(t, err) + require.Len(t, service.Container.Items, 1) + + fs := afs.New() + object, err := fs.Object(context.Background(), "mem://localhost/test/routes_modify/dev/vendor.yml") + require.NoError(t, err) + + err = service.onModify(context.Background(), object) + require.NoError(t, err) + + require.Len(t, service.Container.Items, 1) + aPath := &contract.Path{URI: "/v1/api/dev/hauth/vendors/{vendorID}", Method: "GET"} + element := service.Lookup(aPath) + require.NotNil(t, element) +} diff --git a/repository/report.go b/repository/report.go new file mode 100644 index 000000000..37baa30cf --- /dev/null +++ b/repository/report.go @@ -0,0 +1,8 @@ +package repository + +import reportmodel "github.com/viant/datly/repository/report" + +type Report = reportmodel.Config +type ReportMetadata = reportmodel.Metadata +type ReportField = reportmodel.Field +type ReportFilter = reportmodel.Filter diff --git a/repository/report/build.go b/repository/report/build.go new file mode 100644 index 000000000..6d054e6a0 --- /dev/null +++ b/repository/report/build.go @@ -0,0 +1,333 @@ +package report + +import ( + "context" + "embed" + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" + "github.com/viant/xdatly/codec" + "github.com/viant/xreflect" +) + +type Component struct { + Name string + InputName string + Parameters state.Parameters + View *view.View + Resource state.Resource + Report *Config +} + +func AssembleMetadata(component *Component, cfg *Config) (*Metadata, error) { + if component == nil { + return nil, fmt.Errorf("report component was empty") + } + cfg = normalizeConfig(component, cfg) + viewRef := component.View + if viewRef == nil { + return nil, fmt.Errorf("report component view was empty") + } + result := &Metadata{ + InputName: cfg.InputTypeName(component.Name, component.InputName, viewRef.Name), + BodyFieldName: "", + DimensionsKey: cfg.Dimensions, + MeasuresKey: cfg.Measures, + FiltersKey: cfg.Filters, + OrderBy: cfg.OrderBy, + Limit: cfg.Limit, + Offset: cfg.Offset, + } + for _, column := range viewRef.Columns { + if column == nil || column.FieldName() == "" { + continue + } + fieldName := exportedFieldName(column.FieldName()) + field := &Field{Name: column.FieldName(), FieldName: fieldName, Description: column.Name} + switch { + case column.Groupable: + field.Section = cfg.Dimensions + result.Dimensions = append(result.Dimensions, field) + case column.Aggregate || (viewRef.Groupable && !column.Groupable): + field.Section = cfg.Measures + result.Measures = append(result.Measures, field) + } + } + for _, parameter := range component.Parameters { + if parameter == nil || len(parameter.Predicates) == 0 || parameter.In == nil { + continue + } + if isSelectorParameter(parameter, viewRef) { + continue + } + result.Filters = append(result.Filters, &Filter{ + Name: parameter.Name, + FieldName: exportedFieldName(parameter.Name), + Section: cfg.Filters, + Description: parameter.Description, + Parameter: parameter, + }) + } + if err := result.ValidateSelection(); err != nil { + return nil, err + } + return result, nil +} + +func BuildBodyType(metadata *Metadata) reflect.Type { + var fields []reflect.StructField + fields = append(fields, reflect.StructField{ + Name: metadata.DimensionsKey, + Type: sectionStructType(metadata.Dimensions), + Tag: buildTag(lowerCamel(metadata.DimensionsKey), "Selected grouping dimensions"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.MeasuresKey, + Type: sectionStructType(metadata.Measures), + Tag: buildTag(lowerCamel(metadata.MeasuresKey), "Selected aggregate measures"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.FiltersKey, + Type: filterStructType(metadata.Filters), + Tag: buildTag(lowerCamel(metadata.FiltersKey), "Report filters derived from original predicate parameters"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.OrderBy, + Type: reflect.TypeOf([]string{}), + Tag: buildTag(lowerCamel(metadata.OrderBy), "Ordering expressions applied to the grouped result"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.Limit, + Type: reflect.TypeOf((*int)(nil)), + Tag: buildTag(lowerCamel(metadata.Limit), "Maximum number of grouped rows to return"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.Offset, + Type: reflect.TypeOf((*int)(nil)), + Tag: buildTag(lowerCamel(metadata.Offset), "Row offset applied to the grouped result"), + }) + return reflect.StructOf(fields) +} + +func BuildInputType(component *Component, metadata *Metadata, cfg *Config) (*state.Type, error) { + if component == nil { + return nil, fmt.Errorf("report component was empty") + } + if metadata == nil { + return nil, fmt.Errorf("report metadata was empty") + } + cfg = normalizeConfig(component, cfg) + if cfg.Input != "" { + schema := state.NewSchema(nil, state.WithSchemaPackage(""), state.WithModulePath("")) + schema.Name = strings.TrimSpace(cfg.Input) + inputType, err := state.NewType(state.WithSchema(schema), state.WithResource(component.resource())) + if err != nil { + return nil, err + } + if err := inputType.Init(); err != nil { + return nil, err + } + return inputType, validateExplicitInput(inputType, metadata) + } + bodyType := reflect.PtrTo(BuildBodyType(metadata)) + bodySchema := state.NewSchema(bodyType) + bodySchema.Name = metadata.InputName + bodyParam := state.NewParameter("Report", state.NewBodyLocation(""), state.WithParameterSchema(bodySchema)) + bodyParam.Tag = `anonymous:"true"` + bodyParam.SetTypeNameTag() + inputType, err := state.NewType( + state.WithParameters(state.Parameters{bodyParam}), + state.WithBodyType(true), + state.WithSchema(state.NewSchema(bodyType)), + state.WithResource(newInputResource(component.resource())), + ) + if err != nil { + return nil, err + } + if err := inputType.Init(); err != nil { + return nil, err + } + inputType.Name = metadata.InputName + return inputType, nil +} + +func normalizeConfig(component *Component, cfg *Config) *Config { + if cfg != nil { + return cfg.Normalize() + } + if component == nil || component.Report == nil { + return (&Config{}).Normalize() + } + return component.Report.Normalize() +} + +func (c *Component) resource() state.Resource { + if c == nil { + return nil + } + if c.Resource != nil { + return c.Resource + } + if c.View != nil { + return c.View.Resource() + } + return nil +} + +func exportedFieldName(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + return state.SanitizeTypeName(value) +} + +func isSelectorParameter(parameter *state.Parameter, aView *view.View) bool { + if parameter == nil || parameter.In == nil { + return false + } + if aView != nil && aView.Selector != nil { + for _, selector := range []*state.Parameter{ + aView.Selector.FieldsParameter, + aView.Selector.OrderByParameter, + aView.Selector.LimitParameter, + aView.Selector.OffsetParameter, + aView.Selector.PageParameter, + } { + if selector != nil && selector.In != nil && selector.In.Name == parameter.In.Name { + return true + } + } + } + name := strings.ToLower(parameter.In.Name) + return name == "_fields" || name == "_orderby" || name == "_limit" || name == "_offset" || name == "_page" || name == "criteria" +} + +func validateExplicitInput(inputType *state.Type, metadata *Metadata) error { + if inputType == nil { + return fmt.Errorf("explicit report input type was empty") + } + var rType reflect.Type + if inputType.Schema != nil { + rType = inputType.Schema.Type() + } + if rType == nil && inputType.Type() != nil { + rType = inputType.Type().Type() + } + if rType == nil { + return fmt.Errorf("explicit report input state type was empty") + } + if rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + for _, fieldName := range []string{metadata.DimensionsKey, metadata.MeasuresKey, metadata.FiltersKey, metadata.OrderBy, metadata.Limit, metadata.Offset} { + if fieldName == "" { + continue + } + if _, ok := rType.FieldByName(fieldName); !ok { + return fmt.Errorf("explicit report input %s missing field %s", rType.String(), fieldName) + } + } + return nil +} + +func sectionStructType(fields []*Field) reflect.Type { + if len(fields) == 0 { + return reflect.TypeOf(struct{}{}) + } + structFields := make([]reflect.StructField, 0, len(fields)) + for _, field := range fields { + structFields = append(structFields, reflect.StructField{ + Name: field.FieldName, + Type: reflect.TypeOf(false), + Tag: buildTag(lowerCamel(field.Name), field.Description), + }) + } + return reflect.StructOf(structFields) +} + +func filterStructType(filters []*Filter) reflect.Type { + if len(filters) == 0 { + return reflect.TypeOf(struct{}{}) + } + structFields := make([]reflect.StructField, 0, len(filters)) + for _, filter := range filters { + rType := reflect.TypeOf("") + if schemaType := filter.SchemaType(); schemaType != nil { + rType = schemaType + } + structFields = append(structFields, reflect.StructField{ + Name: filter.FieldName, + Type: rType, + Tag: buildTag(lowerCamel(filter.Name), filter.Description), + }) + } + return reflect.StructOf(structFields) +} + +func buildTag(jsonName, description string) reflect.StructTag { + result := fmt.Sprintf(`json:"%s,omitempty"`, jsonName) + if description = strings.TrimSpace(description); description != "" { + result += " desc:" + strconv.Quote(description) + } + return reflect.StructTag(result) +} + +func lowerCamel(value string) string { + if value == "" { + return "" + } + return text.CaseFormatUpperCamel.Format(value, text.CaseFormatLowerCamel) +} + +type inputResource struct { + base state.Resource +} + +func newInputResource(base state.Resource) state.Resource { + return &inputResource{base: base} +} + +func (r *inputResource) LookupParameter(name string) (*state.Parameter, error) { return nil, nil } +func (r *inputResource) AppendParameter(parameter *state.Parameter) {} +func (r *inputResource) ViewSchema(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *inputResource) ViewSchemaPointer(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *inputResource) LookupType() xreflect.LookupType { return nil } +func (r *inputResource) LoadText(ctx context.Context, URL string) (string, error) { + return "", nil +} +func (r *inputResource) Codecs() *codec.Registry { + if r.base != nil && r.base.Codecs() != nil { + return r.base.Codecs() + } + return codec.New() +} +func (r *inputResource) CodecOptions() *codec.Options { + if r.base != nil && r.base.CodecOptions() != nil { + return r.base.CodecOptions() + } + return codec.NewOptions(nil) +} +func (r *inputResource) ExpandSubstitutes(value string) string { + if r.base != nil { + return r.base.ExpandSubstitutes(value) + } + return value +} +func (r *inputResource) ReverseSubstitutes(value string) string { + if r.base != nil { + return r.base.ReverseSubstitutes(value) + } + return value +} +func (r *inputResource) EmbedFS() *embed.FS { return nil } +func (r *inputResource) SetFSEmbedder(embedder *state.FSEmbedder) {} diff --git a/repository/report/build_test.go b/repository/report/build_test.go new file mode 100644 index 000000000..0551a5412 --- /dev/null +++ b/repository/report/build_test.go @@ -0,0 +1,281 @@ +package report + +import ( + "context" + "embed" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/extension" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" + "github.com/viant/xdatly/codec" + "github.com/viant/xreflect" +) + +type testResource struct{} + +type explicitReportInput struct { + Dimensions struct { + AccountID bool + UserCreated bool + } + Measures struct { + TotalSpend bool + } + Filters struct { + AccountId int + } + OrderBy []string + Limit *int + Offset *int +} + +func (r *testResource) LookupParameter(name string) (*state.Parameter, error) { return nil, nil } +func (r *testResource) AppendParameter(parameter *state.Parameter) {} +func (r *testResource) ViewSchema(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *testResource) ViewSchemaPointer(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *testResource) LookupType() xreflect.LookupType { return nil } +func (r *testResource) LoadText(ctx context.Context, URL string) (string, error) { + return "", nil +} +func (r *testResource) Codecs() *codec.Registry { return codec.New() } +func (r *testResource) CodecOptions() *codec.Options { return codec.NewOptions(nil) } +func (r *testResource) ExpandSubstitutes(value string) string { return value } +func (r *testResource) ReverseSubstitutes(value string) string { return value } +func (r *testResource) EmbedFS() *embed.FS { return nil } +func (r *testResource) SetFSEmbedder(embedder *state.FSEmbedder) { +} + +func TestAssembleMetadata(t *testing.T) { + tests := []struct { + name string + component *Component + config *Config + assertion func(t *testing.T, got *Metadata, err error) + }{ + { + name: "uses component report defaults", + component: newComponentFixture(t, &Config{Enabled: true}), + assertion: func(t *testing.T, got *Metadata, err error) { + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, "VendorInputReportInput", got.InputName) + assert.Equal(t, "Dimensions", got.DimensionsKey) + assert.Equal(t, "Measures", got.MeasuresKey) + assert.Equal(t, "Filters", got.FiltersKey) + require.Len(t, got.Dimensions, 2) + require.Len(t, got.Measures, 1) + require.Len(t, got.Filters, 1) + assert.Equal(t, "AccountID", got.Dimensions[0].Name) + assert.Equal(t, "UserCreated", got.Dimensions[1].Name) + assert.Equal(t, "TotalSpend", got.Measures[0].Name) + assert.Equal(t, "accountID", got.Filters[0].Name) + assert.Equal(t, "AccountId", got.Filters[0].FieldName) + }, + }, + { + name: "uses explicit config names", + component: newComponentFixture(t, &Config{Enabled: true}), + config: &Config{ + Input: "CustomReportInput", + Dimensions: "Groups", + Measures: "Metrics", + Filters: "Predicates", + OrderBy: "Sort", + Limit: "PageSize", + Offset: "Cursor", + }, + assertion: func(t *testing.T, got *Metadata, err error) { + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, "CustomReportInput", got.InputName) + assert.Equal(t, "Groups", got.DimensionsKey) + assert.Equal(t, "Metrics", got.MeasuresKey) + assert.Equal(t, "Predicates", got.FiltersKey) + assert.Equal(t, "Sort", got.OrderBy) + assert.Equal(t, "PageSize", got.Limit) + assert.Equal(t, "Cursor", got.Offset) + assert.Equal(t, "Groups", got.Dimensions[0].Section) + assert.Equal(t, "Metrics", got.Measures[0].Section) + assert.Equal(t, "Predicates", got.Filters[0].Section) + }, + }, + { + name: "errors on missing view", + component: &Component{Report: &Config{Enabled: true}}, + assertion: func(t *testing.T, got *Metadata, err error) { + require.Error(t, err) + assert.Nil(t, got) + assert.Contains(t, err.Error(), "view was empty") + }, + }, + { + name: "errors when no selectable columns", + component: newComponentWithoutSelectableColumns(t), + assertion: func(t *testing.T, got *Metadata, err error) { + require.Error(t, err) + assert.Nil(t, got) + assert.Contains(t, err.Error(), "no selectable dimensions or measures") + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := AssembleMetadata(test.component, test.config) + test.assertion(t, got, err) + }) + } +} + +func TestBuildInputType(t *testing.T) { + tests := []struct { + name string + component *Component + config *Config + assertion func(t *testing.T, got *state.Type, err error) + }{ + { + name: "builds synthetic anonymous body input", + component: newComponentFixture(t, &Config{Enabled: true}), + assertion: func(t *testing.T, got *state.Type, err error) { + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, "VendorInputReportInput", got.Name) + require.Len(t, got.Parameters, 1) + assert.True(t, got.Parameters[0].IsAnonymous()) + require.NotNil(t, got.Schema) + rType := got.Schema.Type() + require.NotNil(t, rType) + assert.Equal(t, reflect.Ptr, rType.Kind()) + bodyType := rType.Elem() + dimensions, ok := bodyType.FieldByName("Dimensions") + require.True(t, ok) + assert.Equal(t, `json:"dimensions,omitempty" desc:"Selected grouping dimensions"`, string(dimensions.Tag)) + measures, ok := bodyType.FieldByName("Measures") + require.True(t, ok) + assert.Equal(t, reflect.Struct, measures.Type.Kind()) + filters, ok := bodyType.FieldByName("Filters") + require.True(t, ok) + filterField, ok := filters.Type.FieldByName("AccountId") + require.True(t, ok) + assert.Contains(t, string(filterField.Tag), `json:"accountId,omitempty"`) + assert.Contains(t, string(filterField.Tag), `desc:"Account identifier filter"`) + limit, ok := bodyType.FieldByName("Limit") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), limit.Type) + }, + }, + { + name: "uses explicit configured input type", + component: newComponentWithExplicitInput(t), + config: (&Config{Input: "ExplicitReportInput"}).Normalize(), + assertion: func(t *testing.T, got *state.Type, err error) { + require.NoError(t, err) + require.NotNil(t, got) + require.NotNil(t, got.Type()) + require.NotNil(t, got.Type().Type()) + assert.Equal(t, reflect.TypeOf(explicitReportInput{}), got.Type().Type()) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + metadata, err := AssembleMetadata(test.component, test.config) + require.NoError(t, err) + got, err := BuildInputType(test.component, metadata, test.config) + test.assertion(t, got, err) + }) + } +} + +func newComponentFixture(t *testing.T, reportCfg *Config) *Component { + t.Helper() + resource := view.EmptyResource() + columnResource := &testResource{} + rootView := view.NewView("vendor", "VENDOR") + rootView.Groupable = true + rootView.Selector = &view.Config{ + FieldsParameter: &state.Parameter{Name: "fields", In: state.NewQueryLocation("_fields")}, + OrderByParameter: &state.Parameter{Name: "orderBy", In: state.NewQueryLocation("_orderby")}, + LimitParameter: &state.Parameter{Name: "limit", In: state.NewQueryLocation("_limit")}, + OffsetParameter: &state.Parameter{Name: "offset", In: state.NewQueryLocation("_offset")}, + } + rootView.Columns = []*view.Column{ + view.NewColumn("AccountID", "int", reflect.TypeOf(0), false), + view.NewColumn("UserCreated", "int", reflect.TypeOf(0), false), + view.NewColumn("TotalSpend", "float64", reflect.TypeOf(float64(0)), false), + } + rootView.Columns[0].Groupable = true + rootView.Columns[1].Groupable = true + rootView.Columns[2].Aggregate = true + for _, column := range rootView.Columns { + require.NoError(t, column.Init(columnResource, text.CaseFormatUndefined, false)) + } + rootView.SetResource(resource) + resource.AddViews(rootView) + + inputType, err := state.NewType(state.WithParameters(state.Parameters{ + &state.Parameter{Name: "vendorIDs", In: state.NewQueryLocation("vendorIDs"), Schema: state.NewSchema(reflect.TypeOf([]int{})), Description: "Vendor IDs to include"}, + &state.Parameter{Name: "accountID", In: state.NewQueryLocation("accountID"), Schema: state.NewSchema(reflect.TypeOf(0)), Predicates: []*extension.PredicateConfig{{Name: "ByAccount"}}, Description: "Account identifier filter"}, + &state.Parameter{Name: "fields", In: state.NewQueryLocation("_fields"), Schema: state.NewSchema(reflect.TypeOf([]string{}))}, + }), state.WithResource(columnResource)) + require.NoError(t, err) + inputType.Name = "VendorInput" + + return &Component{ + Name: "vendors", + InputName: inputType.Name, + Parameters: inputType.Parameters, + View: rootView, + Resource: rootView.Resource(), + Report: reportCfg, + } +} + +func newComponentWithoutSelectableColumns(t *testing.T) *Component { + t.Helper() + resource := view.EmptyResource() + columnResource := &testResource{} + rootView := view.NewView("vendor", "VENDOR") + rootView.Groupable = false + rootView.Columns = []*view.Column{ + view.NewColumn("PlainValue", "int", reflect.TypeOf(0), false), + } + for _, column := range rootView.Columns { + require.NoError(t, column.Init(columnResource, text.CaseFormatUndefined, false)) + } + rootView.SetResource(resource) + resource.AddViews(rootView) + + inputType, err := state.NewType(state.WithParameters(nil), state.WithResource(columnResource)) + require.NoError(t, err) + + return &Component{ + Name: "vendors", + InputName: inputType.Name, + Parameters: inputType.Parameters, + View: rootView, + Resource: rootView.Resource(), + Report: &Config{Enabled: true}, + } +} + +func newComponentWithExplicitInput(t *testing.T) *Component { + t.Helper() + component := newComponentFixture(t, &Config{Enabled: true, Input: "ExplicitReportInput"}) + resource := view.EmptyResource() + require.NoError(t, resource.TypeRegistry().Register("ExplicitReportInput", xreflect.WithReflectType(reflect.TypeOf(explicitReportInput{})))) + component.View.SetResource(resource) + component.Resource = component.View.Resource() + return component +} diff --git a/repository/report/model.go b/repository/report/model.go new file mode 100644 index 000000000..e1ac447cf --- /dev/null +++ b/repository/report/model.go @@ -0,0 +1,119 @@ +package report + +import ( + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/view/state" +) + +type Config struct { + Enabled bool + MCPTool *bool + Input string + Dimensions string + Measures string + Filters string + OrderBy string + Limit string + Offset string +} + +type Metadata struct { + InputName string + BodyFieldName string + DimensionsKey string + MeasuresKey string + FiltersKey string + Dimensions []*Field + Measures []*Field + Filters []*Filter + OrderBy string + Limit string + Offset string +} + +type Field struct { + Name string + FieldName string + Section string + Description string +} + +type Filter struct { + Name string + FieldName string + Section string + Description string + Parameter *state.Parameter +} + +func (c *Config) Clone() *Config { + if c == nil { + return nil + } + ret := *c + return &ret +} + +func (c *Config) Normalize() *Config { + if c == nil { + return nil + } + ret := c.Clone() + ret.Input = strings.TrimSpace(ret.Input) + ret.Dimensions = defaultString(ret.Dimensions, "Dimensions") + ret.Measures = defaultString(ret.Measures, "Measures") + ret.Filters = defaultString(ret.Filters, "Filters") + ret.OrderBy = defaultString(ret.OrderBy, "OrderBy") + ret.Limit = defaultString(ret.Limit, "Limit") + ret.Offset = defaultString(ret.Offset, "Offset") + return ret +} + +func (c *Config) MCPToolEnabled() bool { + if c == nil || c.MCPTool == nil { + return true + } + return *c.MCPTool +} + +func (c *Config) InputTypeName(componentName, inputName, viewName string) string { + if c != nil && strings.TrimSpace(c.Input) != "" { + return strings.TrimSpace(c.Input) + } + switch { + case strings.TrimSpace(inputName) != "": + return state.SanitizeTypeName(strings.TrimSpace(inputName) + "ReportInput") + case strings.TrimSpace(componentName) != "": + return state.SanitizeTypeName(strings.TrimSpace(componentName) + "ReportInput") + default: + return state.SanitizeTypeName(strings.TrimSpace(viewName) + "ReportInput") + } +} + +func (m *Metadata) ValidateSelection() error { + if m == nil { + return fmt.Errorf("report metadata was empty") + } + if len(m.Dimensions) == 0 && len(m.Measures) == 0 { + return fmt.Errorf("report metadata had no selectable dimensions or measures") + } + return nil +} + +func (f *Filter) SchemaType() reflect.Type { + if f == nil || f.Parameter == nil || f.Parameter.Schema == nil { + return nil + } + return f.Parameter.OutputType() +} + +func defaultString(value, fallback string) string { + value = strings.TrimSpace(value) + if value == "" { + return fallback + } + return value +} diff --git a/repository/report_handler.go b/repository/report_handler.go new file mode 100644 index 000000000..28ac7ae2f --- /dev/null +++ b/repository/report_handler.go @@ -0,0 +1,286 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "reflect" + "strconv" + "strings" + + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/view/state" + xhandler "github.com/viant/xdatly/handler" + xdhttp "github.com/viant/xdatly/handler/http" +) + +type reportHandler struct { + Dispatcher contract.Dispatcher + Path *contract.Path + Metadata *ReportMetadata + Original *Component + BodyType reflect.Type +} + +func (r *reportHandler) Exec(ctx context.Context, session xhandler.Session) (interface{}, error) { + if r == nil || r.Dispatcher == nil || r.Path == nil || r.Metadata == nil || r.Original == nil { + return nil, fmt.Errorf("report handler was not initialized") + } + request, err := session.Http().NewRequest(ctx) + if err != nil { + return nil, err + } + input, err := r.reportInput(ctx, request) + if err != nil { + return nil, err + } + query, err := r.buildQuery(input) + if err != nil { + return nil, err + } + internalReq := request.Clone(ctx) + internalReq.Method = r.Path.Method + internalReq.URL = cloneURL(request.URL) + internalReq.URL.Path = strings.TrimSuffix(request.URL.Path, "/report") + internalReq.URL.RawPath = internalReq.URL.Path + internalReq.URL.RawQuery = query.Encode() + internalReq.RequestURI = internalReq.URL.RequestURI() + redirect := &xdhttp.Route{URL: r.Path.URI, Method: r.Path.Method} + return nil, session.Http().Redirect(ctx, redirect, internalReq) +} + +func (r *reportHandler) reportInput(ctx context.Context, request *http.Request) (interface{}, error) { + input := ctx.Value(xhandler.InputKey) + if request != nil && request.Body != nil && r.BodyType != nil { + payload, err := io.ReadAll(request.Body) + if err != nil { + return nil, err + } + if len(payload) > 0 { + targetType := r.BodyType + for targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + } + target := reflect.New(targetType) + if err := json.Unmarshal(payload, target.Interface()); err != nil { + return nil, err + } + return target.Interface(), nil + } + } + if input != nil { + return input, nil + } + if input == nil { + return nil, fmt.Errorf("report input was empty") + } + return input, nil +} + +func (r *reportHandler) buildQuery(input interface{}) (url.Values, error) { + root := indirectValue(reflect.ValueOf(input)) + if !root.IsValid() || root.Kind() != reflect.Struct { + return nil, fmt.Errorf("unsupported report input type %T", input) + } + root = bodyRoot(root, r.Metadata.BodyFieldName) + query := url.Values{} + fields, err := r.collectSelections(root, r.Metadata.Dimensions, r.Metadata.Measures) + if err != nil { + return nil, err + } + if len(fields) == 0 { + return nil, fmt.Errorf("report requires at least one dimension or measure") + } + if fieldsParameter := r.Original.View.Selector.FieldsParameter; fieldsParameter != nil && fieldsParameter.In != nil { + query.Set(fieldsParameter.In.Name, strings.Join(fields, ",")) + } + if err := r.collectFilters(root, query); err != nil { + return nil, err + } + if err := r.collectStrings(root, r.Metadata.OrderBy, query, r.selectorName(r.Original.View.Selector.OrderByParameter, "_orderby")); err != nil { + return nil, err + } + if err := r.collectInts(root, r.Metadata.Limit, query, r.selectorName(r.Original.View.Selector.LimitParameter, "_limit")); err != nil { + return nil, err + } + if err := r.collectInts(root, r.Metadata.Offset, query, r.selectorName(r.Original.View.Selector.OffsetParameter, "_offset")); err != nil { + return nil, err + } + return query, nil +} + +func (r *reportHandler) selectorName(parameter *state.Parameter, fallback string) string { + if parameter != nil && parameter.In != nil && strings.TrimSpace(parameter.In.Name) != "" { + return parameter.In.Name + } + return fallback +} + +func (r *reportHandler) collectSelections(root reflect.Value, groups ...[]*ReportField) ([]string, error) { + var result []string + for _, group := range groups { + for _, field := range group { + section := fieldByName(root, field.Section) + if !section.IsValid() { + continue + } + value := fieldByName(indirectValue(section), field.FieldName) + if !value.IsValid() || value.Kind() != reflect.Bool { + continue + } + if value.Bool() { + result = append(result, field.Name) + } + } + } + return result, nil +} + +func (r *reportHandler) collectFilters(root reflect.Value, query url.Values) error { + filters := fieldByName(root, r.Metadata.FiltersKey) + if !filters.IsValid() { + return nil + } + filters = indirectValue(filters) + for _, filter := range r.Metadata.Filters { + value := fieldByName(filters, filter.FieldName) + if !value.IsValid() || isEmptyValue(value) { + continue + } + if filter.Parameter == nil || filter.Parameter.In == nil { + continue + } + appendQueryValue(query, filter.Parameter.In.Name, value) + } + return nil +} + +func (r *reportHandler) collectStrings(root reflect.Value, fieldName string, query url.Values, key string) error { + if fieldName == "" { + return nil + } + value := fieldByName(root, fieldName) + if !value.IsValid() { + return nil + } + value = indirectValue(value) + if !value.IsValid() || value.Kind() != reflect.Slice { + return nil + } + var parts []string + for i := 0; i < value.Len(); i++ { + item := indirectValue(value.Index(i)) + if item.IsValid() && item.Kind() == reflect.String && item.Len() > 0 { + parts = append(parts, item.String()) + } + } + if len(parts) > 0 { + query.Set(key, strings.Join(parts, ",")) + } + return nil +} + +func (r *reportHandler) collectInts(root reflect.Value, fieldName string, query url.Values, key string) error { + if fieldName == "" { + return nil + } + value := fieldByName(root, fieldName) + if !value.IsValid() { + return nil + } + value = indirectValue(value) + if !value.IsValid() { + return nil + } + switch value.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + query.Set(key, strconv.FormatInt(value.Int(), 10)) + } + return nil +} + +func appendQueryValue(query url.Values, key string, value reflect.Value) { + value = indirectValue(value) + switch value.Kind() { + case reflect.String: + if value.String() != "" { + query.Add(key, value.String()) + } + case reflect.Bool: + query.Add(key, strconv.FormatBool(value.Bool())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + query.Add(key, strconv.FormatInt(value.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + query.Add(key, strconv.FormatUint(value.Uint(), 10)) + case reflect.Float32, reflect.Float64: + query.Add(key, strconv.FormatFloat(value.Float(), 'f', -1, 64)) + case reflect.Slice, reflect.Array: + for i := 0; i < value.Len(); i++ { + appendQueryValue(query, key, value.Index(i)) + } + } +} + +func fieldByName(root reflect.Value, name string) reflect.Value { + root = indirectValue(root) + if !root.IsValid() || root.Kind() != reflect.Struct || name == "" { + return reflect.Value{} + } + return root.FieldByName(name) +} + +func indirectValue(value reflect.Value) reflect.Value { + for value.IsValid() && value.Kind() == reflect.Ptr { + if value.IsNil() { + return reflect.Value{} + } + value = value.Elem() + } + return value +} + +func isEmptyValue(value reflect.Value) bool { + value = indirectValue(value) + if !value.IsValid() { + return true + } + switch value.Kind() { + case reflect.String, reflect.Array, reflect.Slice, reflect.Map: + return value.Len() == 0 + case reflect.Bool: + return !value.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return value.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return value.Uint() == 0 + case reflect.Float32, reflect.Float64: + return value.Float() == 0 + } + return false +} + +func cloneURL(source *url.URL) *url.URL { + if source == nil { + return &url.URL{} + } + clone := *source + return &clone +} + +func bodyRoot(root reflect.Value, bodyField string) reflect.Value { + if bodyField == "" { + return root + } + body := fieldByName(root, bodyField) + if !body.IsValid() { + return root + } + body = indirectValue(body) + if !body.IsValid() || body.Kind() != reflect.Struct { + return root + } + return body +} diff --git a/repository/report_handler_test.go b/repository/report_handler_test.go new file mode 100644 index 000000000..43b86214d --- /dev/null +++ b/repository/report_handler_test.go @@ -0,0 +1,211 @@ +package repository + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + xhandler "github.com/viant/xdatly/handler" + xdauth "github.com/viant/xdatly/handler/auth" + "github.com/viant/xdatly/handler/differ" + xdhttp "github.com/viant/xdatly/handler/http" + xdlogger "github.com/viant/xdatly/handler/logger" + "github.com/viant/xdatly/handler/mbus" + "github.com/viant/xdatly/handler/sqlx" + xdstate "github.com/viant/xdatly/handler/state" + "github.com/viant/xdatly/handler/validator" +) + +type captureDispatcher struct { + path *contract.Path + options *contract.Options +} + +func (d *captureDispatcher) Dispatch(ctx context.Context, path *contract.Path, options ...contract.Option) (interface{}, error) { + d.path = path + d.options = contract.NewOptions(options...) + return map[string]string{"status": "ok"}, nil +} + +type reportTestHTTP struct { + request *http.Request + redirectRoute *xdhttp.Route + redirectRequest *http.Request +} + +func (h *reportTestHTTP) RequestOf(ctx context.Context, v any) (*http.Request, error) { + return h.request, nil +} +func (h *reportTestHTTP) NewRequest(ctx context.Context, opts ...xdstate.Option) (*http.Request, error) { + return h.request, nil +} +func (h *reportTestHTTP) Redirect(ctx context.Context, route *xdhttp.Route, request *http.Request) error { + h.redirectRoute = route + h.redirectRequest = request + return nil +} +func (h *reportTestHTTP) FailWithCode(statusCode int, err error) error { return err } + +type reportTestLogger struct{} + +func (l *reportTestLogger) IsDebugEnabled() bool { return false } +func (l *reportTestLogger) IsInfoEnabled() bool { return false } +func (l *reportTestLogger) IsWarnEnabled() bool { return false } +func (l *reportTestLogger) IsErrorEnabled() bool { return false } +func (l *reportTestLogger) Info(msg string, args ...any) {} +func (l *reportTestLogger) Debug(msg string, args ...any) {} +func (l *reportTestLogger) Warn(msg string, args ...any) {} +func (l *reportTestLogger) Error(msg string, args ...any) {} +func (l *reportTestLogger) Infoc(ctx context.Context, msg string, args ...any) {} +func (l *reportTestLogger) Debugc(ctx context.Context, msg string, args ...any) {} +func (l *reportTestLogger) DebugJSONc(ctx context.Context, msg string, obj any) {} +func (l *reportTestLogger) Warnc(ctx context.Context, msg string, args ...any) {} +func (l *reportTestLogger) Errorc(ctx context.Context, msg string, args ...any) {} +func (l *reportTestLogger) Infos(ctx context.Context, msg string, attrs ...slog.Attr) {} +func (l *reportTestLogger) Debugs(ctx context.Context, msg string, attrs ...slog.Attr) {} +func (l *reportTestLogger) Warns(ctx context.Context, msg string, attrs ...slog.Attr) {} +func (l *reportTestLogger) Errors(ctx context.Context, msg string, attrs ...slog.Attr) {} + +type reportTestSession struct { + http *reportTestHTTP + logger xdlogger.Logger +} + +type reportHandlerDimensions struct { + AccountID bool +} + +type reportHandlerMeasures struct { + TotalSpend bool +} + +type reportHandlerFilters struct { + AccountID *int +} + +type reportHandlerBody struct { + Dimensions reportHandlerDimensions + Measures reportHandlerMeasures + Filters reportHandlerFilters + OrderBy []string + Limit *int + Offset *int +} + +func (s *reportTestSession) Validator() *validator.Service { return nil } +func (s *reportTestSession) Differ() *differ.Service { return nil } +func (s *reportTestSession) MessageBus() *mbus.Service { return nil } +func (s *reportTestSession) Db(opts ...sqlx.Option) (*sqlx.Service, error) { return nil, nil } +func (s *reportTestSession) Stater() *xdstate.Service { return nil } +func (s *reportTestSession) FlushTemplate(ctx context.Context) error { return nil } +func (s *reportTestSession) Session(ctx context.Context, route *xdhttp.Route, opts ...xdstate.Option) (xhandler.Session, error) { + return s, nil +} +func (s *reportTestSession) Http() xdhttp.Http { return s.http } +func (s *reportTestSession) Auth() xdauth.Auth { return nil } +func (s *reportTestSession) Logger() xdlogger.Logger { return s.logger } + +func testReportHandler() *reportHandler { + return &reportHandler{ + Dispatcher: &captureDispatcher{}, + Path: &contract.Path{Method: http.MethodGet, URI: "/v1/api/vendors"}, + Metadata: &ReportMetadata{ + BodyFieldName: "", + DimensionsKey: "Dimensions", + MeasuresKey: "Measures", + FiltersKey: "Filters", + OrderBy: "OrderBy", + Limit: "Limit", + Offset: "Offset", + Dimensions: []*ReportField{{Name: "AccountID", FieldName: "AccountID", Section: "Dimensions"}}, + Measures: []*ReportField{{Name: "TotalSpend", FieldName: "TotalSpend", Section: "Measures"}}, + Filters: []*ReportFilter{{Name: "accountID", FieldName: "AccountID"}}, + }, + Original: &Component{ + View: &view.View{ + Selector: &view.Config{ + FieldsParameter: &state.Parameter{In: state.NewQueryLocation("_fields")}, + OrderByParameter: &state.Parameter{In: state.NewQueryLocation("_orderby")}, + LimitParameter: &state.Parameter{In: state.NewQueryLocation("_limit")}, + OffsetParameter: &state.Parameter{In: state.NewQueryLocation("_offset")}, + }, + }, + }, + } +} + +func testReportInput() reportHandlerBody { + accountID := 101 + limit := 25 + return reportHandlerBody{ + Dimensions: reportHandlerDimensions{AccountID: true}, + Measures: reportHandlerMeasures{TotalSpend: true}, + Filters: reportHandlerFilters{AccountID: &accountID}, + OrderBy: []string{"AccountID"}, + Limit: &limit, + } +} + +func TestReportHandler_BuildQuery_FromPostBody(t *testing.T) { + handler := testReportHandler() + handler.Metadata.Filters[0].Parameter = &state.Parameter{In: state.NewQueryLocation("accountID")} + query, err := handler.buildQuery(testReportInput()) + require.NoError(t, err) + assert.Equal(t, "AccountID,TotalSpend", query.Get("_fields")) + assert.Equal(t, "AccountID", query.Get("_orderby")) + assert.Equal(t, "25", query.Get("_limit")) + assert.Equal(t, "101", query.Get("accountID")) +} + +func TestReportHandler_Exec_PreservesAuthorizationHeader(t *testing.T) { + handler := testReportHandler() + handler.Metadata.Filters[0].Parameter = &state.Parameter{In: state.NewQueryLocation("accountID")} + + req := httptest.NewRequest(http.MethodPost, "http://localhost/v1/api/vendors/report", nil) + req.Header.Set("Authorization", "Bearer test-token") + httpSession := &reportTestHTTP{request: req} + session := &reportTestSession{ + http: httpSession, + logger: &reportTestLogger{}, + } + + ctx := context.WithValue(context.Background(), xhandler.InputKey, testReportInput()) + _, err := handler.Exec(ctx, session) + require.NoError(t, err) + require.NotNil(t, httpSession.redirectRoute) + require.NotNil(t, httpSession.redirectRequest) + assert.Equal(t, "Bearer test-token", httpSession.redirectRequest.Header.Get("Authorization")) + assert.Equal(t, "/v1/api/vendors", httpSession.redirectRequest.URL.Path) + assert.Equal(t, http.MethodGet, httpSession.redirectRoute.Method) + assert.Equal(t, "/v1/api/vendors", httpSession.redirectRoute.URL) + query := httpSession.redirectRequest.URL.Query() + assert.Equal(t, "AccountID,TotalSpend", query.Get("_fields")) + assert.Equal(t, "AccountID", query.Get("_orderby")) + assert.Equal(t, "25", query.Get("_limit")) + assert.Equal(t, "101", query.Get("accountID")) +} + +func TestReportHandler_ReportInput_AcceptsUnwrappedBody(t *testing.T) { + handler := testReportHandler() + handler.BodyType = reflect.TypeOf(&reportHandlerBody{}) + payload, err := json.Marshal(testReportInput()) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "http://localhost/v1/api/vendors/report", io.NopCloser(bytes.NewReader(payload))) + input, err := handler.reportInput(context.Background(), req) + require.NoError(t, err) + body, ok := input.(*reportHandlerBody) + require.True(t, ok) + require.True(t, body.Dimensions.AccountID) + require.True(t, body.Measures.TotalSpend) +} diff --git a/repository/report_runtime.go b/repository/report_runtime.go new file mode 100644 index 000000000..b7d685e88 --- /dev/null +++ b/repository/report_runtime.go @@ -0,0 +1,375 @@ +package repository + +import ( + "context" + "embed" + "fmt" + "net/http" + "reflect" + "strconv" + "strings" + + "github.com/viant/datly/repository/contract" + rephandler "github.com/viant/datly/repository/handler" + "github.com/viant/datly/repository/path" + reportmodel "github.com/viant/datly/repository/report" + "github.com/viant/datly/service" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" + "github.com/viant/xdatly/codec" + "github.com/viant/xreflect" +) + +func (s *Service) appendReportProvider(ctx context.Context, item *path.Item, routePath *path.Path, providers []*Provider, provider *Provider) ([]*Provider, error) { + if routePath == nil || routePath.Report == nil || !routePath.Report.Enabled { + return providers, nil + } + reportPath := buildReportPath(routePath) + reportProvider := &Provider{ + path: reportPath.Path, + control: routePath.Version, + newComponent: func(ctx context.Context, opts ...Option) (*Component, error) { + original, err := provider.Component(ctx, opts...) + if err != nil || original == nil { + return nil, err + } + if !isReportEligible(original) { + return nil, nil + } + component, _, err := buildReportArtifacts(ctx, s.registry.Dispatcher(), original, routePath) + return component, err + }, + } + item.Paths = append(item.Paths, reportPath) + providers = append(providers, reportProvider) + return providers, nil +} + +func isReportEligible(component *Component) bool { + if component == nil || component.Report == nil || !component.Report.Enabled { + return false + } + if component.View == nil || !component.View.Groupable { + return false + } + return strings.EqualFold(component.Method, http.MethodGet) +} + +func (s *Service) buildReportComponent(original *Component, routePath *path.Path) (*Component, *path.Path, error) { + return buildReportArtifacts(context.Background(), s.registry.Dispatcher(), original, routePath) +} + +func BuildReportComponent(dispatcher contract.Dispatcher, original *Component) (*Component, error) { + component, _, err := buildReportArtifacts(context.Background(), dispatcher, original, nil) + return component, err +} + +func buildReportArtifacts(ctx context.Context, dispatcher contract.Dispatcher, original *Component, routePath *path.Path) (*Component, *path.Path, error) { + config := original.Report.Normalize() + metadata, err := buildReportMetadata(original, config) + if err != nil { + return nil, nil, err + } + inputType, err := buildReportInputType(original, metadata, config) + if err != nil { + return nil, nil, err + } + reportURI := strings.TrimSuffix(original.URI, "/") + "/report" + ret := *original + ret.Path = contract.Path{Method: http.MethodPost, URI: reportURI} + ret.Handler = rephandler.NewHandler(&reportHandler{ + Dispatcher: dispatcher, + Path: &original.Path, + Metadata: metadata, + Original: original, + BodyType: inputType.Schema.Type(), + }) + ret.Service = service.TypeExecutor + ret.Report = config + ret.View = buildReportWrapperView(original.View) + ret.Async = nil + ret.Input.Type = *inputType + var reportPath *path.Path + if routePath != nil { + pathCopy := *routePath + pathCopy.Path = ret.Path + pathCopy.View = routePath.View + pathCopy.Internal = routePath.Internal + pathCopy.Meta = routePath.Meta + pathCopy.ModelContextProtocol = routePath.ModelContextProtocol + pathCopy.MCPTool = config.MCPToolEnabled() + pathCopy.MCPResource = false + pathCopy.MCPTemplateResource = false + pathCopy.Report = routePath.Report + if pathCopy.Name != "" { + pathCopy.Name += " Report" + } + if pathCopy.Description != "" { + pathCopy.Description += " report" + } + reportPath = &pathCopy + } + return &ret, reportPath, nil +} + +func buildReportWrapperView(original *view.View) *view.View { + if original == nil { + return nil + } + ret := &view.View{ + Name: original.Name + "#report", + Description: original.Description, + Module: original.Module, + Alias: original.Alias, + Mode: view.ModeHandler, + Connector: original.Connector, + CaseFormat: original.CaseFormat, + Groupable: original.Groupable, + Selector: &view.Config{}, + } + if original.Schema != nil { + ret.Schema = original.Schema.Clone() + } + ret.SetResource(original.GetResource()) + return ret +} + +func buildReportPath(routePath *path.Path) *path.Path { + pathCopy := *routePath + pathCopy.Path = contract.Path{ + Method: http.MethodPost, + URI: strings.TrimSuffix(routePath.URI, "/") + "/report", + } + pathCopy.MCPTool = reportPathMCPToolEnabled(routePath.Report) + pathCopy.MCPResource = false + pathCopy.MCPTemplateResource = false + if pathCopy.Name != "" { + pathCopy.Name += " Report" + } + if pathCopy.Description != "" { + pathCopy.Description += " report" + } + return &pathCopy +} + +func reportPathMCPToolEnabled(report *path.Report) bool { + if report == nil || report.MCPTool == nil { + return true + } + return *report.MCPTool +} + +func buildReportMetadata(component *Component, report *Report) (*ReportMetadata, error) { + source := &reportmodel.Component{ + Name: component.Name, + InputName: component.Input.Type.Name, + Parameters: component.Input.Type.Parameters, + View: component.View, + Resource: component.View.Resource(), + Report: report, + } + return reportmodel.AssembleMetadata(source, report) +} + +func buildReportInputType(component *Component, metadata *ReportMetadata, report *Report) (*state.Type, error) { + source := &reportmodel.Component{ + Name: component.Name, + InputName: component.Input.Type.Name, + Parameters: component.Input.Type.Parameters, + View: component.View, + Resource: component.View.Resource(), + Report: report, + } + return reportmodel.BuildInputType(source, metadata, report) +} + +func validateExplicitReportInput(inputType *state.Type, metadata *ReportMetadata) error { + if inputType == nil || inputType.Type() == nil { + return fmt.Errorf("explicit report input type was empty") + } + rType := inputType.Type().Type() + if rType == nil { + return fmt.Errorf("explicit report input state type was empty") + } + rType = reflectTypeOfState(rType) + for _, fieldName := range []string{metadata.DimensionsKey, metadata.MeasuresKey, metadata.FiltersKey, metadata.OrderBy, metadata.Limit, metadata.Offset} { + if fieldName == "" { + continue + } + if _, ok := rType.FieldByName(fieldName); !ok { + return fmt.Errorf("explicit report input %s missing field %s", rType.String(), fieldName) + } + } + return nil +} + +func synthesizeReportBodyType(metadata *ReportMetadata) reflect.Type { + var fields []reflect.StructField + fields = append(fields, reflect.StructField{ + Name: metadata.DimensionsKey, + Type: sectionStructType(metadata.Dimensions), + Tag: buildReportTag(lowerCamel(metadata.DimensionsKey), "Selected grouping dimensions"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.MeasuresKey, + Type: sectionStructType(metadata.Measures), + Tag: buildReportTag(lowerCamel(metadata.MeasuresKey), "Selected aggregate measures"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.FiltersKey, + Type: filterStructType(metadata.Filters), + Tag: buildReportTag(lowerCamel(metadata.FiltersKey), "Report filters derived from original predicate parameters"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.OrderBy, + Type: reflect.TypeOf([]string{}), + Tag: buildReportTag(lowerCamel(metadata.OrderBy), "Ordering expressions applied to the grouped result"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.Limit, + Type: reflect.TypeOf((*int)(nil)), + Tag: buildReportTag(lowerCamel(metadata.Limit), "Maximum number of grouped rows to return"), + }) + fields = append(fields, reflect.StructField{ + Name: metadata.Offset, + Type: reflect.TypeOf((*int)(nil)), + Tag: buildReportTag(lowerCamel(metadata.Offset), "Row offset applied to the grouped result"), + }) + return reflect.StructOf(fields) +} + +func sectionStructType(fields []*ReportField) reflect.Type { + if len(fields) == 0 { + return reflect.TypeOf(struct{}{}) + } + structFields := make([]reflect.StructField, 0, len(fields)) + for _, field := range fields { + structFields = append(structFields, reflect.StructField{ + Name: field.FieldName, + Type: reflect.TypeOf(false), + Tag: buildReportTag(lowerCamel(field.Name), field.Description), + }) + } + return reflect.StructOf(structFields) +} + +func filterStructType(filters []*ReportFilter) reflect.Type { + if len(filters) == 0 { + return reflect.TypeOf(struct{}{}) + } + structFields := make([]reflect.StructField, 0, len(filters)) + for _, filter := range filters { + rType := reflect.TypeOf("") + if schemaType := filter.SchemaType(); schemaType != nil { + rType = schemaType + } + structFields = append(structFields, reflect.StructField{ + Name: filter.FieldName, + Type: rType, + Tag: buildReportTag(lowerCamel(filter.Name), filter.Description), + }) + } + return reflect.StructOf(structFields) +} + +func buildReportTag(jsonName, description string) reflect.StructTag { + result := fmt.Sprintf(`json:"%s,omitempty"`, jsonName) + if description = strings.TrimSpace(description); description != "" { + result += " desc:" + strconv.Quote(description) + } + return reflect.StructTag(result) +} + +func isSelectorParameter(parameter *state.Parameter, aView *view.View) bool { + if parameter == nil || parameter.In == nil { + return false + } + if aView != nil && aView.Selector != nil { + for _, selector := range []*state.Parameter{ + aView.Selector.FieldsParameter, + aView.Selector.OrderByParameter, + aView.Selector.LimitParameter, + aView.Selector.OffsetParameter, + aView.Selector.PageParameter, + } { + if selector != nil && selector.In != nil && selector.In.Name == parameter.In.Name { + return true + } + } + } + name := strings.ToLower(parameter.In.Name) + return name == "_fields" || name == "_orderby" || name == "_limit" || name == "_offset" || name == "_page" || name == "criteria" +} + +func lowerCamel(value string) string { + if value == "" { + return "" + } + return text.CaseFormatUpperCamel.Format(value, text.CaseFormatLowerCamel) +} + +func exportedReportFieldName(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + return state.SanitizeTypeName(value) +} + +func reflectTypeOfState(rType reflect.Type) reflect.Type { + if rType == nil { + return nil + } + if rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + return rType +} + +type reportInputResource struct { + base state.Resource +} + +func newReportInputResource(base state.Resource) state.Resource { + return &reportInputResource{base: base} +} + +func (r *reportInputResource) LookupParameter(name string) (*state.Parameter, error) { return nil, nil } +func (r *reportInputResource) AppendParameter(parameter *state.Parameter) {} +func (r *reportInputResource) ViewSchema(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *reportInputResource) ViewSchemaPointer(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *reportInputResource) LookupType() xreflect.LookupType { return nil } +func (r *reportInputResource) LoadText(ctx context.Context, URL string) (string, error) { + return "", nil +} +func (r *reportInputResource) Codecs() *codec.Registry { + if r.base != nil && r.base.Codecs() != nil { + return r.base.Codecs() + } + return codec.New() +} +func (r *reportInputResource) CodecOptions() *codec.Options { + if r.base != nil && r.base.CodecOptions() != nil { + return r.base.CodecOptions() + } + return codec.NewOptions(nil) +} +func (r *reportInputResource) ExpandSubstitutes(value string) string { + if r.base != nil { + return r.base.ExpandSubstitutes(value) + } + return value +} +func (r *reportInputResource) ReverseSubstitutes(value string) string { + if r.base != nil { + return r.base.ReverseSubstitutes(value) + } + return value +} +func (r *reportInputResource) EmbedFS() *embed.FS { return nil } +func (r *reportInputResource) SetFSEmbedder(embedder *state.FSEmbedder) {} diff --git a/repository/report_runtime_test.go b/repository/report_runtime_test.go new file mode 100644 index 000000000..c57cbbc8f --- /dev/null +++ b/repository/report_runtime_test.go @@ -0,0 +1,359 @@ +package repository + +import ( + "context" + "embed" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/path" + "github.com/viant/datly/view" + "github.com/viant/datly/view/extension" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" + "github.com/viant/xdatly/codec" + "github.com/viant/xreflect" +) + +type reportTestResource struct{} + +func (r *reportTestResource) LookupParameter(name string) (*state.Parameter, error) { return nil, nil } +func (r *reportTestResource) AppendParameter(parameter *state.Parameter) {} +func (r *reportTestResource) ViewSchema(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *reportTestResource) ViewSchemaPointer(ctx context.Context, name string) (*state.Schema, error) { + return nil, nil +} +func (r *reportTestResource) LookupType() xreflect.LookupType { return nil } +func (r *reportTestResource) LoadText(ctx context.Context, URL string) (string, error) { + return "", nil +} +func (r *reportTestResource) Codecs() *codec.Registry { return codec.New() } +func (r *reportTestResource) CodecOptions() *codec.Options { return codec.NewOptions(nil) } +func (r *reportTestResource) ExpandSubstitutes(value string) string { return value } +func (r *reportTestResource) ReverseSubstitutes(value string) string { return value } +func (r *reportTestResource) EmbedFS() *embed.FS { return nil } +func (r *reportTestResource) SetFSEmbedder(embedder *state.FSEmbedder) {} + +func TestBuildReportMetadataAndComponent(t *testing.T) { + resource := view.EmptyResource() + columnResource := &reportTestResource{} + rootView := view.NewView("vendor", "VENDOR") + rootView.Groupable = true + rootView.Selector = &view.Config{ + FieldsParameter: &state.Parameter{Name: "fields", In: state.NewQueryLocation("_fields")}, + OrderByParameter: &state.Parameter{Name: "orderBy", In: state.NewQueryLocation("_orderby")}, + LimitParameter: &state.Parameter{Name: "limit", In: state.NewQueryLocation("_limit")}, + OffsetParameter: &state.Parameter{Name: "offset", In: state.NewQueryLocation("_offset")}, + } + rootView.Columns = []*view.Column{ + view.NewColumn("AccountID", "int", reflect.TypeOf(0), false), + view.NewColumn("UserCreated", "int", reflect.TypeOf(0), false), + view.NewColumn("TotalSpend", "float64", reflect.TypeOf(float64(0)), false), + } + rootView.Columns[0].Groupable = true + rootView.Columns[1].Groupable = true + rootView.Columns[2].Aggregate = true + for _, column := range rootView.Columns { + require.NoError(t, column.Init(columnResource, text.CaseFormatUndefined, false)) + } + rootView.SetResource(resource) + resource.AddViews(rootView) + + inputType, err := state.NewType(state.WithParameters(state.Parameters{ + &state.Parameter{Name: "vendorIDs", In: state.NewQueryLocation("vendorIDs"), Schema: state.NewSchema(reflect.TypeOf([]int{})), Description: "Vendor IDs to include"}, + &state.Parameter{Name: "accountID", In: state.NewQueryLocation("accountID"), Schema: state.NewSchema(reflect.TypeOf(0)), Predicates: []*extension.PredicateConfig{{Name: "ByAccount"}}, Description: "Account identifier filter"}, + &state.Parameter{Name: "fields", In: state.NewQueryLocation("_fields"), Schema: state.NewSchema(reflect.TypeOf([]string{}))}, + }), state.WithResource(columnResource)) + require.NoError(t, err) + inputType.Name = "VendorInput" + + component := &Component{ + Path: contract.Path{Method: "GET", URI: "/v1/api/vendors"}, + Meta: contract.Meta{Name: "vendors"}, + View: rootView, + Report: (&Report{Enabled: true}).Normalize(), + Contract: contract.Contract{ + Input: contract.Input{Type: *inputType}, + }, + } + + metadata, err := buildReportMetadata(component, component.Report) + require.NoError(t, err) + require.NotNil(t, metadata) + assert.Equal(t, "VendorInputReportInput", metadata.InputName) + require.Len(t, metadata.Dimensions, 2) + require.Len(t, metadata.Measures, 1) + require.Len(t, metadata.Filters, 1) + assert.Equal(t, "AccountID", metadata.Dimensions[0].Name) + assert.Equal(t, "TotalSpend", metadata.Measures[0].Name) + assert.Equal(t, "accountID", metadata.Filters[0].Name) + + service := &Service{registry: NewRegistry("", nil, nil)} + reportComponent, reportPath, err := service.buildReportComponent(component, &path.Path{ + Path: component.Path, + View: &path.ViewRef{Ref: rootView.Name}, + ModelContextProtocol: contract.ModelContextProtocol{ + MCPTool: true, + }, + Meta: contract.Meta{ + Name: "vendors", + Description: "Vendor listing", + }, + Report: &path.Report{Enabled: true}, + }) + require.NoError(t, err) + require.NotNil(t, reportComponent) + require.NotNil(t, reportPath) + assert.Equal(t, "POST", reportComponent.Method) + assert.Equal(t, "/v1/api/vendors/report", reportComponent.URI) + require.NotNil(t, reportComponent.Report) + require.NotNil(t, reportComponent.View) + assert.NotSame(t, component.View, reportComponent.View) + assert.Equal(t, view.ModeHandler, reportComponent.View.Mode) + assert.Nil(t, reportComponent.View.Template) + require.Len(t, reportComponent.Input.Type.Parameters, 1) + assert.True(t, reportComponent.Input.Type.Parameters[0].IsAnonymous()) + assert.Equal(t, "/v1/api/vendors/report", reportPath.URI) + assert.Equal(t, "POST", reportPath.Method) + assert.True(t, reportPath.MCPTool) + assert.Equal(t, "vendors Report", reportPath.Name) + assert.Equal(t, "Vendor listing report", reportPath.Description) + reportInputType, err := buildReportInputType(component, metadata, component.Report) + require.NoError(t, err) + require.NotNil(t, reportInputType) + require.NotNil(t, reportInputType.Schema) + require.NotNil(t, reportInputType.Schema.Type()) + bodyType := reportInputType.Schema.Type() + if bodyType.Kind() == reflect.Ptr { + bodyType = bodyType.Elem() + } + _, ok := bodyType.FieldByName("Dimensions") + assert.True(t, ok) + _, ok = bodyType.FieldByName("Measures") + assert.True(t, ok) + _, ok = bodyType.FieldByName("Filters") + assert.True(t, ok) + filtersField, ok := bodyType.FieldByName("Filters") + require.True(t, ok) + filterType := filtersField.Type + require.Greater(t, filterType.NumField(), 0) + filterField := filterType.Field(0) + assert.True(t, strings.Contains(string(filterField.Tag), `desc:"Account identifier filter"`)) +} + +func TestBuildReportComponent_EnablesMCPToolOnSiblingRoute(t *testing.T) { + resource := view.EmptyResource() + rootView := view.NewView("vendor", "VENDOR") + rootView.Groupable = true + rootView.Columns = []*view.Column{ + view.NewColumn("AccountID", "int", reflect.TypeOf(0), false), + view.NewColumn("TotalSpend", "float64", reflect.TypeOf(float64(0)), false), + } + rootView.Columns[0].Groupable = true + rootView.Columns[1].Aggregate = true + for _, column := range rootView.Columns { + require.NoError(t, column.Init(&reportTestResource{}, text.CaseFormatUndefined, false)) + } + rootView.SetResource(resource) + resource.AddViews(rootView) + + inputType, err := state.NewType(state.WithParameters(state.Parameters{ + &state.Parameter{Name: "accountID", In: state.NewQueryLocation("accountID"), Schema: state.NewSchema(reflect.TypeOf(0)), Predicates: []*extension.PredicateConfig{{Name: "ByAccount"}}, Description: "Account identifier filter"}, + }), state.WithResource(&reportTestResource{})) + require.NoError(t, err) + inputType.Name = "VendorInput" + + component := &Component{ + Path: contract.Path{Method: "GET", URI: "/v1/api/vendors"}, + Meta: contract.Meta{Name: "vendors"}, + View: rootView, + Report: (&Report{Enabled: true}).Normalize(), + Contract: contract.Contract{ + Input: contract.Input{Type: *inputType}, + }, + } + + service := &Service{registry: NewRegistry("", nil, nil)} + _, reportPath, err := service.buildReportComponent(component, &path.Path{ + Path: component.Path, + View: &path.ViewRef{Ref: rootView.Name}, + ModelContextProtocol: contract.ModelContextProtocol{ + MCPTool: false, + MCPResource: true, + MCPTemplateResource: true, + }, + Meta: contract.Meta{ + Name: "vendors", + Description: "Vendor listing", + }, + Report: &path.Report{Enabled: true}, + }) + require.NoError(t, err) + require.NotNil(t, reportPath) + assert.True(t, reportPath.MCPTool) + assert.False(t, reportPath.MCPResource) + assert.False(t, reportPath.MCPTemplateResource) +} + +func TestBuildReportComponent_DisablesMCPToolWhenReportFlagIsFalse(t *testing.T) { + resource := view.EmptyResource() + rootView := view.NewView("vendor", "VENDOR") + rootView.Groupable = true + rootView.Columns = []*view.Column{ + view.NewColumn("AccountID", "int", reflect.TypeOf(0), false), + view.NewColumn("TotalSpend", "float64", reflect.TypeOf(float64(0)), false), + } + rootView.Columns[0].Groupable = true + rootView.Columns[1].Aggregate = true + for _, column := range rootView.Columns { + require.NoError(t, column.Init(&reportTestResource{}, text.CaseFormatUndefined, false)) + } + rootView.SetResource(resource) + resource.AddViews(rootView) + + inputType, err := state.NewType(state.WithParameters(state.Parameters{ + &state.Parameter{Name: "accountID", In: state.NewQueryLocation("accountID"), Schema: state.NewSchema(reflect.TypeOf(0)), Predicates: []*extension.PredicateConfig{{Name: "ByAccount"}}, Description: "Account identifier filter"}, + }), state.WithResource(&reportTestResource{})) + require.NoError(t, err) + inputType.Name = "VendorInput" + + disabled := false + component := &Component{ + Path: contract.Path{Method: "GET", URI: "/v1/api/vendors"}, + Meta: contract.Meta{Name: "vendors"}, + View: rootView, + Report: (&Report{ + Enabled: true, + MCPTool: &disabled, + }).Normalize(), + Contract: contract.Contract{ + Input: contract.Input{Type: *inputType}, + }, + } + + service := &Service{registry: NewRegistry("", nil, nil)} + _, reportPath, err := service.buildReportComponent(component, &path.Path{ + Path: component.Path, + View: &path.ViewRef{Ref: rootView.Name}, + ModelContextProtocol: contract.ModelContextProtocol{ + MCPTool: true, + MCPResource: true, + MCPTemplateResource: true, + }, + Meta: contract.Meta{ + Name: "vendors", + Description: "Vendor listing", + }, + Report: &path.Report{Enabled: true, MCPTool: &disabled}, + }) + require.NoError(t, err) + require.NotNil(t, reportPath) + assert.False(t, reportPath.MCPTool) + assert.False(t, reportPath.MCPResource) + assert.False(t, reportPath.MCPTemplateResource) +} + +func TestService_InitComponentProviders_RegistersLocalGroupingReportRoute(t *testing.T) { + ctx := context.Background() + baseDir, err := filepath.Abs(filepath.Join("..", "e2e", "local", "regression")) + require.NoError(t, err) + if _, err := os.Stat(filepath.Join(baseDir, "paths.yaml")); err != nil { + t.Skipf("missing local regression fixture: %v", err) + } + service, err := New(ctx, + WithComponentURL(baseDir), + WithResourceURL(baseDir), + WithNoPlugin(), + WithRefreshDisabled(true), + ) + require.NoError(t, err) + reportPath := &contract.Path{Method: "POST", URI: "/v1/api/dev/vendors-grouping/report"} + provider, err := service.Registry().LookupProvider(ctx, reportPath) + require.NoError(t, err) + require.NotNil(t, provider) + component, err := provider.Component(ctx) + require.NoError(t, err) + require.NotNil(t, component) + require.NotNil(t, component.Report) + assert.True(t, component.Report.Enabled) + assert.Equal(t, "POST", component.Method) + assert.Equal(t, "/v1/api/dev/vendors-grouping/report", component.URI) +} + +func TestBuildReportComponent_DoesNotStripOriginalViewTypeDefinitionsFromCodegen(t *testing.T) { + resource := view.EmptyResource() + rootView := view.NewView("metrics_view", "metrics_view") + rootView.Groupable = true + rootView.Connector = &view.Connector{Connection: view.Connection{DBConfig: view.DBConfig{Name: "dev"}}} + rootView.Template = &view.Template{Source: "SELECT agency_id, SUM(total_spend) AS total_spend FROM metrics_view GROUP BY 1"} + rootView.Schema = state.NewSchema(reflect.TypeOf([]*struct { + AgencyId *int `sqlx:"agency_id"` + TotalSpend *float64 `sqlx:"total_spend"` + }{})) + rootView.Columns = []*view.Column{ + view.NewColumn("AgencyId", "int", reflect.TypeOf(0), false), + view.NewColumn("TotalSpend", "float64", reflect.TypeOf(float64(0)), false), + } + rootView.Columns[0].Groupable = true + rootView.Columns[1].Aggregate = true + for _, column := range rootView.Columns { + require.NoError(t, column.Init(&reportTestResource{}, text.CaseFormatUndefined, false)) + } + resource.Types = []*view.TypeDefinition{ + {Name: "MetricsViewView", Package: "metrics", DataType: `struct{AgencyId *int ` + "`sqlx:\"agency_id\"`" + `; TotalSpend *float64 ` + "`sqlx:\"total_spend\"`" + `;}`}, + } + require.NoError(t, resource.TypeRegistry().Register("MetricsViewView", xreflect.WithPackage("metrics"), xreflect.WithReflectType(reflect.TypeOf(struct { + AgencyId *int `sqlx:"agency_id"` + TotalSpend *float64 `sqlx:"total_spend"` + }{})))) + rootView.SetResource(resource) + resource.AddViews(rootView) + + inputType, err := state.NewType(state.WithParameters(state.Parameters{ + &state.Parameter{Name: "agencyID", In: state.NewQueryLocation("agency_id"), Schema: state.NewSchema(reflect.TypeOf(0)), Predicates: []*extension.PredicateConfig{{Name: "ByAgency"}}, Description: "Agency filter"}, + }), state.WithResource(&reportTestResource{})) + require.NoError(t, err) + inputType.Name = "MetricsViewInput" + + outputType, err := state.NewType(state.WithParameters(state.Parameters{ + &state.Parameter{Name: "Data", In: state.NewOutputLocation("view"), Schema: &state.Schema{Name: "MetricsViewView", Package: "metrics", Cardinality: state.Many}}, + }), state.WithResource(rootView.Resource())) + require.NoError(t, err) + outputType.Name = "MetricsViewOutput" + + component := &Component{ + Path: contract.Path{Method: "GET", URI: "/v1/api/core/metrics/performance_summary"}, + Meta: contract.Meta{Name: "MetricsPerformance"}, + View: rootView, + Report: (&Report{Enabled: true}).Normalize(), + Contract: contract.Contract{ + Input: contract.Input{Type: *inputType}, + Output: contract.Output{Type: *outputType}, + }, + } + + before := component.GenerateOutputCode(context.Background(), true, false, nil) + require.Contains(t, before, "type Data struct") + + service := &Service{registry: NewRegistry("", nil, nil)} + _, _, err = service.buildReportComponent(component, &path.Path{ + Path: component.Path, + View: &path.ViewRef{Ref: rootView.Name}, + Report: &path.Report{ + Enabled: true, + }, + }) + require.NoError(t, err) + + after := component.GenerateOutputCode(context.Background(), true, false, nil) + require.Contains(t, after, "type Data struct") + assert.Equal(t, before, after) +} diff --git a/repository/service.go b/repository/service.go index fdd95286c..fcd7cbf3b 100644 --- a/repository/service.go +++ b/repository/service.go @@ -34,6 +34,7 @@ type ( auth *auth.Service plugins *plugin.Service refreshFrequency time.Duration + refreshDisabled bool options *Options } @@ -68,6 +69,9 @@ func (s *Service) Container() *path.Container { // SyncChanges checks if resource, plugin or components have changes // if so it would increase individual or all component/paths version number resulting in lazy reload func (s *Service) SyncChanges(ctx context.Context) (bool, error) { + if s == nil || s.refreshDisabled { + return false, nil + } now := time.Now() //fmt.Printf("[INFO] sync changes started\n") snap := &snapshot{} @@ -243,6 +247,7 @@ func (s *Service) initComponentProviders(ctx context.Context) error { paths := s.paths.GetPaths() pathsLen := len(paths.Items) var providers []*Provider + var err error for i := 0; i < pathsLen; i++ { route := paths.Items[i] sourceURL := route.SourceURL @@ -259,6 +264,10 @@ func (s *Service) initComponentProviders(ctx context.Context) error { return nil, fmt.Errorf("no component for path: %s", aPath.Path.Key()) }) providers = append(providers, provider) + providers, err = s.appendReportProvider(ctx, route, aPath, providers, provider) + if err != nil { + return err + } } } s.registry.SetProviders(providers) @@ -346,6 +355,7 @@ func New(ctx context.Context, opts ...Option) (*Service, error) { ret := &Service{ options: options, refreshFrequency: options.refreshFrequency, + refreshDisabled: options.refreshDisabled, resources: options.resources, extensions: options.extensions, } diff --git a/repository/service_refresh_test.go b/repository/service_refresh_test.go new file mode 100644 index 000000000..30fcc3789 --- /dev/null +++ b/repository/service_refresh_test.go @@ -0,0 +1,19 @@ +package repository + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_SyncChanges_RefreshDisabled(t *testing.T) { + service := &Service{ + refreshDisabled: true, + } + + changed, err := service.SyncChanges(context.Background()) + require.NoError(t, err) + assert.False(t, changed) +} diff --git a/repository/shape/column/detector.go b/repository/shape/column/detector.go index 79b6c8d1c..72ea292f3 100644 --- a/repository/shape/column/detector.go +++ b/repository/shape/column/detector.go @@ -8,6 +8,7 @@ import ( "github.com/viant/datly/view" viewcolumn "github.com/viant/datly/view/column" + "github.com/viant/datly/view/state" "github.com/viant/sqlparser" "github.com/viant/sqlx/io" ) @@ -31,7 +32,17 @@ func (d *Detector) Resolve(ctx context.Context, resource *view.Resource, aView * } base := columnsFromSchema(aView) - if !usesWildcard(aView) { + // If columns are placeholders (col_1, col_2, etc.) from static inference, treat as no columns + if allPlaceholderColumns(aView.Columns) { + base = nil + } + if explicit := explicitProjectedSubqueryColumns(aView); len(explicit) > 0 { + if len(base) == 0 { + return explicit, nil + } + return mergePreservingOrder(base, explicit), nil + } + if !needsDiscovery(aView) && len(base) > 0 { return base, nil } @@ -54,14 +65,591 @@ func (d *Detector) detect(ctx context.Context, resource *view.Resource, aView *v if err != nil { return nil, fmt.Errorf("shape column detector: failed to open db for view %s: %w", aView.Name, err) } - query := sourceSQL(aView) - sqlColumns, err := viewcolumn.Discover(ctx, db, aView.Table, query) + query := discoverySQL(aView, resource) + table := resolveDiscoveryTable(aView, resource, sourceSQL(aView)) + sqlColumns, err := viewcolumn.Discover(ctx, db, table, query) if err != nil { - return nil, fmt.Errorf("shape column detector: discover failed for view %s: %w", aView.Name, err) + return nil, fmt.Errorf("shape column detector: discover failed for view %s (query=%q, table=%q): %w", aView.Name, query, table, err) } return view.NewColumns(sqlColumns, aView.ColumnsConfig), nil } +// discoverySQL returns SQL suitable for column discovery. +// Strategy: +// 1. Strip template variables ($var, #if...#end, ${expr}) +// 2. Inject 1=0 into every SELECT in the query (CTEs, UNIONs, subqueries) +// This ensures zero rows scanned — safe for BigQuery (no full scan cost) +// 3. Fall back to table name if parsing/falsification fails +func discoverySQL(aView *view.View, resource *view.Resource) string { + raw := sourceSQL(aView) + if expanded := applyConstValuesForDiscovery(raw, resource); strings.TrimSpace(expanded) != "" { + raw = expanded + } + table := resolveDiscoveryTable(aView, resource, raw) + if raw == "" { + return table + } + // EXCEPT clause is a datly projection extension; table fallback is safest. + if table != "" && hasExceptClause(raw) { + return table + } + // Template SQL with wildcard fallback to table metadata. For explicit projection + // we still derive columns from SQL (after template stripping) to avoid widening + // contract to the whole table. + if table != "" && hasTemplateVariables(raw) && usesWildcard(aView) { + return table + } + // For clean SQL without templates, try to falsify for column type inference + cleaned := strings.TrimSpace(raw) + if hasTemplateVariables(cleaned) { + cleaned = strings.TrimSpace(stripTemplateVariables(cleaned)) + } + if cleaned == "" || !strings.Contains(strings.ToLower(cleaned), "select") { + if table != "" { + return table + } + return cleaned + } + if falsified, ok := falsifyQuery(cleaned); ok { + return falsified + } + // Fallback to table + if table != "" { + return table + } + return cleaned +} + +func explicitProjectedSubqueryColumns(aView *view.View) view.Columns { + if aView == nil || !usesWildcard(aView) { + return nil + } + sql := strings.TrimSpace(sourceSQL(aView)) + if sql == "" { + return nil + } + queryNode, err := sqlparser.ParseQuery(sql) + if err != nil || queryNode == nil || !queryNode.List.IsStarExpr() || queryNode.From.X == nil { + return nil + } + fromExpr := strings.TrimSpace(sqlparser.Stringify(queryNode.From.X)) + if fromExpr == "" { + return nil + } + fromExpr = strings.TrimSpace(strings.TrimPrefix(fromExpr, "(")) + fromExpr = strings.TrimSpace(strings.TrimSuffix(fromExpr, ")")) + if !strings.Contains(strings.ToLower(fromExpr), "select") { + return nil + } + innerQuery, err := sqlparser.ParseQuery(fromExpr) + if err != nil || innerQuery == nil { + return nil + } + columns := sqlparser.NewColumns(innerQuery.List) + if len(columns) == 0 || columns.IsStarExpr() { + return nil + } + normalizeExplicitProjectedColumnTypes(columns) + return view.NewColumns(columns, aView.ColumnsConfig) +} + +func normalizeExplicitProjectedColumnTypes(columns sqlparser.Columns) { + for _, column := range columns { + if column == nil || strings.TrimSpace(column.Type) != "" { + continue + } + expression := strings.TrimSpace(column.Expression) + trimmed := strings.TrimSpace(strings.Trim(expression, "()")) + switch { + case trimmed == "": + continue + case trimmed == "true" || trimmed == "false": + column.Type = "bool" + case isIntegerLiteral(trimmed): + column.Type = "int" + case isFloatLiteral(trimmed): + column.Type = "float64" + case isQuotedLiteral(trimmed): + column.Type = "string" + } + } +} + +func isIntegerLiteral(value string) bool { + if value == "" { + return false + } + for i, ch := range value { + if i == 0 && (ch == '-' || ch == '+') { + if len(value) == 1 { + return false + } + continue + } + if ch < '0' || ch > '9' { + return false + } + } + return true +} + +func isFloatLiteral(value string) bool { + if value == "" || strings.Count(value, ".") != 1 { + return false + } + value = strings.ReplaceAll(value, ".", "") + return isIntegerLiteral(value) +} + +func isQuotedLiteral(value string) bool { + return len(value) >= 2 && ((value[0] == '\'' && value[len(value)-1] == '\'') || (value[0] == '"' && value[len(value)-1] == '"')) +} + +func resolveDiscoveryTable(aView *view.View, resource *view.Resource, rawSQL string) string { + table := "" + if aView != nil { + table = strings.TrimSpace(aView.Table) + } + if expanded := strings.TrimSpace(applyConstValuesForDiscovery(table, resource)); expanded != "" { + table = expanded + } + table = normalizeDiscoveryTable(table) + if table == "" { + table = inferDiscoveryTable(rawSQL) + } + return table +} + +func applyConstValuesForDiscovery(sql string, resource *view.Resource) string { + if strings.TrimSpace(sql) == "" || resource == nil || len(resource.Parameters) == 0 { + return sql + } + consts := map[string]string{} + for _, item := range resource.Parameters { + if item == nil || item.In == nil || item.In.Kind != state.KindConst { + continue + } + name := strings.TrimSpace(item.Name) + if name == "" { + name = strings.TrimSpace(item.In.Name) + } + if name == "" || item.Value == nil { + continue + } + consts[name] = fmt.Sprintf("%v", item.Value) + } + if len(consts) == 0 { + return sql + } + var b strings.Builder + b.Grow(len(sql)) + for i := 0; i < len(sql); { + if sql[i] != '$' { + b.WriteByte(sql[i]) + i++ + continue + } + if i+1 < len(sql) && sql[i+1] == '{' { + end := i + 2 + for end < len(sql) && sql[end] != '}' { + end++ + } + if end >= len(sql) { + b.WriteString(sql[i:]) + break + } + expr := strings.TrimSpace(sql[i+2 : end]) + if value, ok := constFromExpr(expr, consts); ok { + b.WriteString(formatConstForDiscovery(value)) + } else { + b.WriteString(sql[i : end+1]) + } + i = end + 1 + continue + } + end := i + 1 + for end < len(sql) && (isIdentPart(sql[end]) || sql[end] == '.') { + end++ + } + expr := sql[i+1 : end] + if value, ok := constFromExpr(expr, consts); ok { + b.WriteString(formatConstForDiscovery(value)) + } else { + b.WriteString(sql[i:end]) + } + i = end + } + return b.String() +} + +func constFromExpr(expr string, consts map[string]string) (string, bool) { + if expr == "" { + return "", false + } + if strings.HasPrefix(expr, "Unsafe.") { + expr = strings.TrimPrefix(expr, "Unsafe.") + } + if value, ok := consts[expr]; ok { + return value, true + } + for name, value := range consts { + if strings.EqualFold(name, expr) { + return value, true + } + } + return "", false +} + +func formatConstForDiscovery(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "''" + } + for i := 0; i < len(value); i++ { + ch := value[i] + if !(isIdentPart(ch) || ch == '.' || ch == '`') { + escaped := strings.ReplaceAll(value, "'", "''") + return "'" + escaped + "'" + } + } + return value +} + +func normalizeDiscoveryTable(table string) string { + trimmed := strings.TrimSpace(strings.Trim(table, "`\"")) + if strings.HasPrefix(trimmed, "${Unsafe.") && strings.HasSuffix(trimmed, "}") { + trimmed = strings.TrimSuffix(strings.TrimPrefix(trimmed, "${Unsafe."), "}") + } + if strings.HasPrefix(trimmed, "$Unsafe.") { + trimmed = strings.TrimPrefix(trimmed, "$Unsafe.") + } + trimmed = strings.TrimSpace(trimmed) + if trimmed == "" { + return table + } + for i := 0; i < len(trimmed); i++ { + ch := trimmed[i] + if !(isIdentPart(ch) || ch == '.' || ch == '`' || ch == '"') { + return table + } + } + return strings.Trim(trimmed, "`\"") +} + +func inferDiscoveryTable(sql string) string { + lower := strings.ToLower(sql) + idx := strings.Index(lower, " from ") + if idx == -1 { + if token := findUnsafeTableToken(sql); token != "" { + return token + } + return "" + } + pos := idx + len(" from ") + for pos < len(sql) && (sql[pos] == ' ' || sql[pos] == '\t' || sql[pos] == '\n' || sql[pos] == '\r') { + pos++ + } + if pos >= len(sql) { + return "" + } + if strings.HasPrefix(sql[pos:], "${Unsafe.") { + end := strings.Index(sql[pos:], "}") + if end == -1 { + return "" + } + token := strings.TrimSpace(sql[pos+len("${Unsafe.") : pos+end]) + if token == "" { + return "" + } + return token + } + if strings.HasPrefix(sql[pos:], "$Unsafe.") { + start := pos + len("$Unsafe.") + end := start + for end < len(sql) && (isIdentPart(sql[end]) || sql[end] == '.') { + end++ + } + return strings.TrimSpace(sql[start:end]) + } + if sql[pos] == '(' { + depth := 1 + end := pos + 1 + for end < len(sql) && depth > 0 { + switch sql[end] { + case '(': + depth++ + case ')': + depth-- + } + end++ + } + if end > pos+1 { + if nested := inferDiscoveryTable(sql[pos+1 : end-1]); nested != "" { + return nested + } + } + if token := findUnsafeTableToken(sql[pos:]); token != "" { + return token + } + return "" + } + end := pos + for end < len(sql) && (isIdentPart(sql[end]) || sql[end] == '.' || sql[end] == '`' || sql[end] == '"') { + end++ + } + return strings.TrimSpace(strings.Trim(sql[pos:end], "`\"")) +} + +func findUnsafeTableToken(sql string) string { + if idx := strings.Index(sql, "${Unsafe."); idx != -1 { + start := idx + len("${Unsafe.") + end := strings.Index(sql[start:], "}") + if end != -1 { + return strings.TrimSpace(sql[start : start+end]) + } + } + if idx := strings.Index(sql, "$Unsafe."); idx != -1 { + start := idx + len("$Unsafe.") + end := start + for end < len(sql) && (isIdentPart(sql[end]) || sql[end] == '.') { + end++ + } + return strings.TrimSpace(sql[start:end]) + } + return "" +} + +func removeExceptClauses(sql string) string { + // Remove "EXCEPT col1, col2" patterns — these are datly-specific + // Simple approach: remove " EXCEPT (, )*" + result := sql + for { + lower := strings.ToLower(result) + idx := strings.Index(lower, " except ") + if idx == -1 { + break + } + // Find end of EXCEPT clause (next keyword or end of identifier list) + end := idx + len(" except ") + for end < len(result) && (isIdentPart(result[end]) || result[end] == ',' || result[end] == ' ') { + end++ + } + result = result[:idx] + result[end:] + } + return result +} + +func hasTemplateVariables(sql string) bool { + for i := 0; i < len(sql)-1; i++ { + if sql[i] == '$' && isIdentStart(sql[i+1]) { + return true + } + if sql[i] == '#' && (sql[i+1] == 'i' || sql[i+1] == 'f' || sql[i+1] == 'e' || sql[i+1] == 's') { + return true + } + if sql[i] == '$' && sql[i+1] == '{' { + return true + } + } + return false +} + +func hasExceptClause(sql string) bool { + lower := strings.ToLower(sql) + return strings.Contains(lower, " except ") +} + +// needsDiscovery returns true if the view SQL uses wildcards or has no explicit columns. +func needsDiscovery(aView *view.View) bool { + if aView == nil { + return false + } + if len(aView.Columns) == 0 { + return true + } + if allPlaceholderColumns(aView.Columns) { + return true + } + return usesWildcard(aView) +} + +// stripTemplateVariables removes velocity/velty template constructs from SQL +// so it can be parsed and executed for column discovery. +// Handles: $variable, ${expression}, #if...#end, #foreach...#end, #set(...) +func stripTemplateVariables(sql string) string { + var b strings.Builder + b.Grow(len(sql)) + i := 0 + for i < len(sql) { + // Handle # directives: #if, #foreach, #set, #end, #else, #elseif + if sql[i] == '#' && i+1 < len(sql) { + directive := matchDirective(sql, i) + if directive != "" { + // Skip entire directive line/block + end := skipDirective(sql, i, directive) + // Replace with space to preserve SQL structure + b.WriteByte(' ') + i = end + continue + } + } + // Handle $ variables: $name, $name.method(...), ${expression} + if sql[i] == '$' && i+1 < len(sql) { + next := sql[i+1] + if next == '{' { + // ${...} expression — find matching } + depth := 1 + j := i + 2 + for j < len(sql) && depth > 0 { + if sql[j] == '{' { + depth++ + } else if sql[j] == '}' { + depth-- + } + j++ + } + // Replace with empty string or placeholder + b.WriteString("''") + i = j + continue + } + if isIdentStart(next) { + // $varName or $varName.method(...) + j := i + 1 + for j < len(sql) && isIdentPart(sql[j]) { + j++ + } + hasMethodCall := false + methodExpr := "" + // Skip .method() chains + for j < len(sql) && sql[j] == '.' { + methodStart := j + j++ + for j < len(sql) && isIdentPart(sql[j]) { + j++ + } + if j < len(sql) && sql[j] == '(' { + hasMethodCall = true + methodExpr = sql[methodStart:j] + depth := 1 + j++ + for j < len(sql) && depth > 0 { + if sql[j] == '(' { + depth++ + } else if sql[j] == ')' { + depth-- + } + j++ + } + } + } + if hasMethodCall { + if strings.EqualFold(methodExpr, ".AppendBinding") { + b.WriteString("''") + } else { + b.WriteString("") + } + } else { + b.WriteString("''") + } + i = j + continue + } + } + b.WriteByte(sql[i]) + i++ + } + return b.String() +} + +func matchDirective(sql string, pos int) string { + directives := []string{"#foreach", "#if", "#elseif", "#else", "#end", "#set", "#settings", "#setting", "#define", "#package", "#import"} + remaining := sql[pos:] + for _, d := range directives { + if len(remaining) >= len(d) && strings.EqualFold(remaining[:len(d)], d) { + if len(remaining) == len(d) || !isIdentPart(remaining[len(d)]) { + return d + } + } + } + return "" +} + +func skipDirective(sql string, pos int, directive string) int { + switch { + case directive == "#set" || directive == "#settings" || directive == "#setting" || directive == "#define": + // Skip to end of line or matching paren + j := pos + len(directive) + for j < len(sql) && (sql[j] == ' ' || sql[j] == '\t') { + j++ + } + if j < len(sql) && sql[j] == '(' { + depth := 1 + j++ + for j < len(sql) && depth > 0 { + if sql[j] == '(' { + depth++ + } else if sql[j] == ')' { + depth-- + } + j++ + } + return j + } + // Skip to end of line + for j < len(sql) && sql[j] != '\n' { + j++ + } + if j < len(sql) { + j++ + } + return j + case directive == "#foreach" || directive == "#if": + // Skip to matching #end + j := pos + len(directive) + depth := 1 + for j < len(sql) && depth > 0 { + d := matchDirective(sql, j) + if d == "#if" || d == "#foreach" { + depth++ + j += len(d) + } else if d == "#end" { + depth-- + j += len(d) + } else { + j++ + } + } + return j + default: + // #else, #elseif, #end, #package, #import — skip to end of line + j := pos + len(directive) + for j < len(sql) && sql[j] != '\n' { + j++ + } + if j < len(sql) { + j++ + } + return j + } +} + +func allPlaceholderColumns(columns view.Columns) bool { + if len(columns) == 0 { + return false + } + for _, col := range columns { + if col == nil { + continue + } + name := strings.ToLower(col.Name) + if !strings.HasPrefix(name, "col_") { + return false + } + } + return true +} + func lookupConnector(ctx context.Context, resource *view.Resource, aView *view.View) (*view.Connector, error) { if resource == nil { return nil, fmt.Errorf("shape column detector: missing resource for view %s", aView.Name) @@ -129,13 +717,16 @@ func columnsFromSchema(aView *view.View) view.Columns { } result := make(view.Columns, 0, rType.NumField()) appendSchemaColumns(rType, "", &result) + if allPlaceholderColumns(result) { + return nil + } return result } func appendSchemaColumns(rType reflect.Type, ns string, columns *view.Columns) { for i := 0; i < rType.NumField(); i++ { field := rType.Field(i) - if field.PkgPath != "" { // unexported + if field.PkgPath != "" { continue } if field.Anonymous { @@ -148,12 +739,13 @@ func appendSchemaColumns(rType reflect.Type, ns string, columns *view.Columns) { } continue } - + if shouldSkipSchemaField(field) { + continue + } tag := io.ParseTag(field.Tag) if tag != nil && tag.Transient { continue } - name := field.Name if tag != nil && tag.Column != "" { name = tag.Column @@ -163,7 +755,6 @@ func appendSchemaColumns(rType reflect.Type, ns string, columns *view.Columns) { } else if ns != "" { name = ns + name } - columnType := field.Type nullable := false if columnType.Kind() == reflect.Ptr { @@ -174,6 +765,20 @@ func appendSchemaColumns(rType reflect.Type, ns string, columns *view.Columns) { } } +func shouldSkipSchemaField(field reflect.StructField) bool { + if field.Name == "-" { + return true + } + rawTag := string(field.Tag) + if strings.Contains(rawTag, `view:"`) || strings.Contains(rawTag, `on:"`) { + return true + } + if strings.Contains(rawTag, `sqlx:"-"`) { + return true + } + return false +} + func mergePreservingOrder(base, discovered view.Columns) view.Columns { if len(base) == 0 { return discovered @@ -195,7 +800,6 @@ func mergePreservingOrder(base, discovered view.Columns) view.Columns { } if fresh, ok := seen[strings.ToLower(item.Name)]; ok { delete(seen, strings.ToLower(item.Name)) - // Keep schema name/order but refresh discovered metadata. item.DataType = firstNonEmpty(fresh.DataType, item.DataType) item.SetColumnType(firstType(fresh.ColumnType(), item.ColumnType())) item.Nullable = fresh.Nullable @@ -218,6 +822,14 @@ func mergePreservingOrder(base, discovered view.Columns) view.Columns { return result } +func isIdentStart(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' +} + +func isIdentPart(ch byte) bool { + return isIdentStart(ch) || (ch >= '0' && ch <= '9') +} + func firstNonEmpty(values ...string) string { for _, value := range values { if strings.TrimSpace(value) != "" { diff --git a/repository/shape/column/detector_test.go b/repository/shape/column/detector_test.go index cfc834b11..f04b17d49 100644 --- a/repository/shape/column/detector_test.go +++ b/repository/shape/column/detector_test.go @@ -14,6 +14,21 @@ type sampleOrder struct { Name string `sqlx:"name=NAME"` } +type sampleSemanticRoot struct { + ID int `sqlx:"ID"` + Products []*sampleChildView `view:",table=PRODUCT" on:"Id:ID=VendorId:VENDOR_ID" sql:"uri=vendor/products.sql"` + Cities []*sampleChildView `view:",table=CITY" on:"Id:ID=DistrictId:DISTRICT_ID"` + Ignored string `sqlx:"-"` +} + +type sampleChildView struct { + VendorID int `sqlx:"VENDOR_ID"` +} + +func stringPtr(value string) *string { + return &value +} + func TestUsesWildcard(t *testing.T) { tests := []struct { name string @@ -39,6 +54,13 @@ func TestColumnsFromSchema_Order(t *testing.T) { require.Equal(t, "NAME", cols[1].Name) } +func TestColumnsFromSchema_SkipsSemanticRelationFields(t *testing.T) { + aView := &view.View{Schema: state.NewSchema(reflect.TypeOf(sampleSemanticRoot{}), state.WithMany())} + cols := columnsFromSchema(aView) + require.Len(t, cols, 1) + require.Equal(t, "ID", cols[0].Name) +} + func TestMergePreservingOrder_AppendsNewDetectedColumns(t *testing.T) { base := view.Columns{ view.NewColumn("VENDOR_ID", "int", reflect.TypeOf(int(0)), false), @@ -57,3 +79,68 @@ func TestMergePreservingOrder_AppendsNewDetectedColumns(t *testing.T) { require.Equal(t, "bigint", merged[0].DataType) require.Equal(t, "text", merged[1].DataType) } + +func TestApplyConstValuesForDiscovery(t *testing.T) { + resource := view.EmptyResource() + resource.AddParameters( + &state.Parameter{Name: "Vendor", In: state.NewConstLocation("Vendor"), Value: "VENDOR"}, + &state.Parameter{Name: "Product", In: state.NewConstLocation("Product"), Value: "PRODUCT"}, + ) + sql := `SELECT vendor.*, products.* FROM (SELECT * FROM $Vendor t) vendor JOIN (SELECT * FROM ${Unsafe.Product} p) products ON products.VENDOR_ID = vendor.ID` + got := applyConstValuesForDiscovery(sql, resource) + require.Contains(t, got, "FROM (SELECT * FROM VENDOR t)") + require.Contains(t, got, "JOIN (SELECT * FROM PRODUCT p)") + require.NotContains(t, got, "$Vendor") + require.NotContains(t, got, "${Unsafe.Product}") +} + +func TestDiscoverySQL_ResolvesConstTableFallback(t *testing.T) { + resource := view.EmptyResource() + resource.AddParameters( + &state.Parameter{Name: "Vendor", In: state.NewConstLocation("Vendor"), Value: "VENDOR"}, + ) + aView := &view.View{ + Table: "${Unsafe.Vendor}", + Template: view.NewTemplate("SELECT * FROM ${Unsafe.Vendor} t WHERE t.ID = $criteria.AppendBinding($Unsafe.VendorID)"), + } + got := discoverySQL(aView, resource) + require.Equal(t, "VENDOR", got) +} + +func TestNormalizeDiscoveryTable_TemplateUnsafe(t *testing.T) { + require.Equal(t, "Vendor", normalizeDiscoveryTable("${Unsafe.Vendor}")) + require.Equal(t, "Product", normalizeDiscoveryTable("$Unsafe.Product")) +} + +func TestDiscoverySQL_TemplateTableWithoutConstParameter_UsesNormalizedTable(t *testing.T) { + aView := &view.View{ + Table: "${Unsafe.Vendor}", + Template: view.NewTemplate("SELECT * FROM ${Unsafe.Vendor} t WHERE t.ID IN ($criteria.AppendBinding($Unsafe.vendorIDs))"), + } + got := discoverySQL(aView, view.EmptyResource()) + require.Equal(t, "Vendor", got) +} + +func TestExplicitProjectedSubqueryColumns_UsesInnerProjection(t *testing.T) { + aView := &view.View{ + Template: view.NewTemplate("SELECT * FROM (SELECT (1) AS IS_ACTIVE, (3) AS CHANNEL, CAST($criteria.AppendBinding($Unsafe.VendorID) AS SIGNED) AS ID) t"), + ColumnsConfig: map[string]*view.ColumnConfig{ + "ID": {Name: "ID", Tag: stringPtr(`internal:"true"`)}, + }, + } + got := explicitProjectedSubqueryColumns(aView) + require.Len(t, got, 3) + require.Equal(t, "IS_ACTIVE", got[0].Name) + require.Equal(t, "int", got[0].DataType) + require.Equal(t, "CHANNEL", got[1].Name) + require.Equal(t, "int", got[1].DataType) + require.Equal(t, "ID", got[2].Name) + require.Equal(t, ` internal:"true"`, got[2].Tag) +} + +func TestInferDiscoveryTable(t *testing.T) { + require.Equal(t, "Vendor", inferDiscoveryTable("SELECT * FROM ${Unsafe.Vendor} t WHERE 1=1")) + require.Equal(t, "Product", inferDiscoveryTable("SELECT * FROM $Unsafe.Product t WHERE 1=1")) + require.Equal(t, "VENDOR", inferDiscoveryTable("SELECT * FROM VENDOR t WHERE 1=1")) + require.Equal(t, "Vendor", inferDiscoveryTable("SELECT vendor.* FROM (SELECT * FROM ${Unsafe.Vendor} t WHERE 1=1) vendor")) +} diff --git a/repository/shape/column/falsify.go b/repository/shape/column/falsify.go new file mode 100644 index 000000000..0f50e2dca --- /dev/null +++ b/repository/shape/column/falsify.go @@ -0,0 +1,151 @@ +package column + +import ( + "strings" + + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/query" +) + +// falsifyQuery parses an SQL string and injects WHERE 1=0 into every SELECT +// in the query tree (outer query, CTEs, UNIONs). This ensures zero rows are +// scanned while preserving the output schema for column type inference. +// +// Returns the rewritten SQL string and true if successful. +// Returns the original SQL and false if parsing fails. +func falsifyQuery(sql string) (string, bool) { + sql = strings.TrimSpace(sql) + if sql == "" { + return sql, false + } + parsed, err := sqlparser.ParseQuery(sql) + if err != nil { + return sql, false + } + falsifySelect(parsed) + // Remove LIMIT/OFFSET from outer query — we want schema only + parsed.Limit = nil + parsed.Offset = nil + result := sqlparser.Stringify(parsed) + if strings.TrimSpace(result) == "" { + return sql, false + } + return result, true +} + +// falsifySelect injects 1=0 into a SELECT and recursively into all nested SELECTs. +func falsifySelect(sel *query.Select) { + if sel == nil { + return + } + // Inject 1=0 into this SELECT's WHERE clause + injectFalsePredicate(sel) + // Process CTE WITH selects + for _, ws := range sel.WithSelects { + if ws != nil && ws.X != nil { + falsifySelect(ws.X) + ws.Raw = "" // Force Stringify to use modified X instead of original Raw + } + } + // Process UNION branches + if sel.Union != nil && sel.Union.X != nil { + falsifySelect(sel.Union.X) + } + // Process subquery in FROM (if it's a nested SELECT) + falsifyFromSubquery(sel) + // Process JOIN subqueries + for _, join := range sel.Joins { + if join != nil { + falsifyJoinSubquery(join) + } + } +} + +// injectFalsePredicate adds 1=0 to the SELECT's WHERE clause. +func injectFalsePredicate(sel *query.Select) { + if sel == nil { + return + } + fp := &expr.Binary{ + X: &expr.Literal{Value: "1"}, + Op: "=", + Y: &expr.Literal{Value: "0"}, + } + if sel.Qualify == nil || sel.Qualify.X == nil { + sel.Qualify = &expr.Qualify{X: fp} + } else { + sel.Qualify = &expr.Qualify{ + X: &expr.Binary{ + X: fp, + Op: "AND", + Y: sel.Qualify.X, + }, + } + } +} + +// falsifyFromSubquery checks if the FROM clause contains a subquery and falsifies it. +func falsifyFromSubquery(sel *query.Select) { + if sel == nil || sel.From.X == nil { + return + } + switch sub := sel.From.X.(type) { + case *expr.Parenthesis: + falsifySubqueryExpr(sub) + case *expr.Raw: + falsifyRawSubquery(sub) + } +} + +func falsifyRawSubquery(raw *expr.Raw) { + if raw == nil { + return + } + text := strings.TrimSpace(raw.Raw) + if text == "" && raw.Unparsed != "" { + text = strings.TrimSpace(raw.Unparsed) + } + // Strip outer parens if present + if len(text) >= 2 && text[0] == '(' && text[len(text)-1] == ')' { + text = text[1 : len(text)-1] + } + if !strings.Contains(strings.ToLower(text), "select") { + return + } + subQuery, err := sqlparser.ParseQuery(text) + if err != nil { + return + } + falsifySelect(subQuery) + rewritten := sqlparser.Stringify(subQuery) + raw.Raw = "(" + rewritten + ")" +} + +// falsifyJoinSubquery checks if a JOIN's WITH clause contains a subquery and falsifies it. +func falsifyJoinSubquery(join *query.Join) { + if join == nil || join.With == nil { + return + } + if sub, ok := join.With.(*expr.Parenthesis); ok { + falsifySubqueryExpr(sub) + } +} + +// falsifySubqueryExpr attempts to parse and falsify a parenthesized subquery expression. +func falsifySubqueryExpr(paren *expr.Parenthesis) { + if paren == nil || paren.X == nil { + return + } + raw := sqlparser.Stringify(paren.X) + if !strings.Contains(strings.ToLower(strings.TrimSpace(raw)), "select") { + return + } + subQuery, err := sqlparser.ParseQuery(raw) + if err != nil { + return + } + falsifySelect(subQuery) + rewritten := sqlparser.Stringify(subQuery) + paren.X = expr.NewRaw(rewritten) +} diff --git a/repository/shape/column/falsify_test.go b/repository/shape/column/falsify_test.go new file mode 100644 index 000000000..fbc769171 --- /dev/null +++ b/repository/shape/column/falsify_test.go @@ -0,0 +1,144 @@ +package column + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFalsifyQuery(t *testing.T) { + tests := []struct { + name string + input string + wantOK bool + assertions func(t *testing.T, result string) + }{ + { + name: "simple SELECT *", + input: "SELECT * FROM orders", + wantOK: true, + assertions: func(t *testing.T, result string) { + assert.Contains(t, result, "1 = 0") + assert.NotContains(t, strings.ToUpper(result), "LIMIT") + }, + }, + { + name: "SELECT with existing WHERE", + input: "SELECT id, name FROM items WHERE status = 1", + wantOK: true, + assertions: func(t *testing.T, result string) { + assert.Contains(t, result, "1 = 0") + assert.Contains(t, result, "status") + }, + }, + { + name: "SELECT with LIMIT stripped", + input: "SELECT * FROM items LIMIT 100 OFFSET 50", + wantOK: true, + assertions: func(t *testing.T, result string) { + assert.Contains(t, result, "1 = 0") + assert.NotContains(t, strings.ToUpper(result), "LIMIT") + assert.NotContains(t, strings.ToUpper(result), "OFFSET") + }, + }, + { + name: "UNION ALL — both branches get 1=0", + input: "SELECT id, name FROM items_a WHERE region = 'us' UNION ALL SELECT id, name FROM items_b WHERE region = 'eu'", + wantOK: true, + assertions: func(t *testing.T, result string) { + count := strings.Count(result, "1 = 0") + assert.GreaterOrEqual(t, count, 2, "both UNION branches should get 1=0") + }, + }, + { + name: "CTE — all CTEs and outer get 1=0", + input: `WITH metrics AS ( + SELECT category, SUM(amount) AS total + FROM transactions + GROUP BY category +), +ranked AS ( + SELECT *, ROW_NUMBER() OVER (ORDER BY total DESC) AS rn + FROM metrics +) +SELECT * FROM ranked WHERE rn <= 10`, + wantOK: true, + assertions: func(t *testing.T, result string) { + count := strings.Count(result, "1 = 0") + assert.GreaterOrEqual(t, count, 3, "each CTE + outer should get 1=0, got %d", count) + assert.NotContains(t, strings.ToUpper(result), "LIMIT") + }, + }, + { + name: "CTE with UNION inside", + input: `WITH combined AS ( + SELECT id, name FROM items_a + UNION ALL + SELECT id, name FROM items_b +) +SELECT * FROM combined`, + wantOK: true, + assertions: func(t *testing.T, result string) { + count := strings.Count(result, "1 = 0") + assert.GreaterOrEqual(t, count, 3, "CTE branches + outer should all get 1=0") + }, + }, + { + name: "JOIN query — outer gets 1=0", + input: "SELECT a.id, b.name FROM orders a JOIN items b ON a.item_id = b.id", + wantOK: true, + assertions: func(t *testing.T, result string) { + assert.Contains(t, result, "1 = 0") + }, + }, + { + name: "subquery in FROM — both get 1=0", + input: "SELECT t.* FROM (SELECT id, name FROM items WHERE active = 1) t", + wantOK: true, + assertions: func(t *testing.T, result string) { + count := strings.Count(result, "1 = 0") + assert.GreaterOrEqual(t, count, 2, "outer + subquery should get 1=0") + }, + }, + { + name: "empty SQL", + input: "", + wantOK: false, + }, + { + name: "non-SELECT statement", + input: "INSERT INTO items VALUES (1, 'test')", + wantOK: true, // parser may still parse it; falsify is best-effort + }, + { + name: "GROUP BY preserved", + input: "SELECT category, COUNT(*) AS cnt FROM items GROUP BY category", + wantOK: true, + assertions: func(t *testing.T, result string) { + assert.Contains(t, result, "1 = 0") + assert.Contains(t, strings.ToUpper(result), "GROUP BY") + }, + }, + { + name: "ORDER BY preserved", + input: "SELECT * FROM items ORDER BY name", + wantOK: true, + assertions: func(t *testing.T, result string) { + assert.Contains(t, result, "1 = 0") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, ok := falsifyQuery(tt.input) + assert.Equal(t, tt.wantOK, ok) + if ok && tt.assertions != nil { + require.NotEmpty(t, result) + tt.assertions(t, result) + } + }) + } +} diff --git a/repository/shape/column/strip_test.go b/repository/shape/column/strip_test.go new file mode 100644 index 000000000..df847559e --- /dev/null +++ b/repository/shape/column/strip_test.go @@ -0,0 +1,318 @@ +package column + +import ( + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestStripTemplateVariables(t *testing.T) { + tests := []struct { + name string + input string + expect string + }{ + { + name: "no templates", + input: "SELECT * FROM VENDOR WHERE ID = 1", + expect: "SELECT * FROM VENDOR WHERE ID = 1", + }, + { + name: "simple variable", + input: "SELECT * FROM VENDOR WHERE ID = $vendorID", + expect: "SELECT * FROM VENDOR WHERE ID = ''", + }, + { + name: "variable with dot method", + input: "SELECT * FROM VENDOR WHERE ID IN ($Unsafe.vendorIDs)", + expect: "SELECT * FROM VENDOR WHERE ID IN ('')", + }, + { + name: "variable with method call", + input: "SELECT * FROM PRODUCT WHERE 1=1 $View.ParentJoinOn(\"AND\",\"VENDOR_ID\")", + expect: "SELECT * FROM PRODUCT WHERE 1=1 ", + }, + { + name: "criteria binding", + input: "SELECT * FROM VENDOR t WHERE t.ID IN ($criteria.AppendBinding($Unsafe.vendorIDs))", + expect: "SELECT * FROM VENDOR t WHERE t.ID IN ('')", + }, + { + name: "expression in braces", + input: "SELECT * FROM VENDOR WHERE ${predicate.Build(\"AND\")}", + expect: "SELECT * FROM VENDOR WHERE ''", + }, + { + name: "if directive", + input: "SELECT * FROM PRODUCT WHERE 1=1 #if($vendorID < 0) AND 1=2 #end", + expect: "SELECT * FROM PRODUCT WHERE 1=1 ", + }, + { + name: "foreach directive", + input: "#foreach($item in $items) INSERT INTO T VALUES($item.ID) #end", + expect: " ", + }, + { + name: "set directive with parens", + input: "#set($x = 1)\nSELECT * FROM T", + expect: " \nSELECT * FROM T", + }, + { + name: "mixed templates and SQL", + input: "SELECT vendor.*, products.* FROM (SELECT * FROM VENDOR t) vendor JOIN (SELECT * FROM PRODUCT t WHERE 1=1 ${predicate.Builder().CombineOr($predicate.FilterGroup(0, \"AND\")).Build(\"AND\")}) products ON products.VENDOR_ID = vendor.ID", + expect: "SELECT vendor.*, products.* FROM (SELECT * FROM VENDOR t) vendor JOIN (SELECT * FROM PRODUCT t WHERE 1=1 '') products ON products.VENDOR_ID = vendor.ID", + }, + { + name: "UNION ALL with templates", + input: "SELECT ID, NAME, VENDOR_ID FROM PRODUCT t WHERE 1=1 $View.ParentJoinOn(\"AND\",\"VENDOR_ID\") UNION ALL SELECT ID, NAME, VENDOR_ID FROM PRODUCT t WHERE 1=1 $View.ParentJoinOn(\"AND\",\"VENDOR_ID\")", + expect: "SELECT ID, NAME, VENDOR_ID FROM PRODUCT t WHERE 1=1 UNION ALL SELECT ID, NAME, VENDOR_ID FROM PRODUCT t WHERE 1=1 ", + }, + { + name: "nested if", + input: "SELECT * FROM T WHERE 1=1 #if($a > 0) AND A=$a #if($b > 0) AND B=$b #end #end", + expect: "SELECT * FROM T WHERE 1=1 ", + }, + { + name: "const variable substitution", + input: "SELECT * FROM $Vendor t WHERE t.ID IN ($vendorIDs)", + expect: "SELECT * FROM '' t WHERE t.ID IN ('')", + }, + { + name: "settings directive at top", + input: "#setting($_ = $route('/api/v1/test', 'GET'))\nSELECT * FROM T", + expect: " \nSELECT * FROM T", + }, + { + name: "package and import directives", + input: "#package('dev/vendor')\n#import('pkg', 'github.com/acme/pkg')\nSELECT * FROM T", + expect: " SELECT * FROM T", + }, + { + name: "complex predicate builder", + input: "WHERE ${predicate.Builder().CombineOr($predicate.FilterGroup(0, \"AND\")).Build(\"AND\")}", + expect: "WHERE ''", + }, + { + name: "dollar at end", + input: "SELECT * FROM T WHERE X = $", + expect: "SELECT * FROM T WHERE X = $", + }, + { + name: "dollar number (not a variable)", + input: "SELECT * FROM T WHERE X = $1", + expect: "SELECT * FROM T WHERE X = $1", + }, + { + name: "cast expression", + input: "SELECT CAST($Jwt.FirstName AS CHAR) AS FIRST_NAME FROM T", + expect: "SELECT CAST('' AS CHAR) AS FIRST_NAME FROM T", + }, + { + name: "logger and unsafe", + input: "#foreach($rec in $Unsafe.Records) UPDATE T SET V=$rec.Value WHERE ID=$rec.ID; #end", + expect: " ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripTemplateVariables(tt.input) + assert.Equal(t, tt.expect, got) + }) + } +} + +func TestStripTemplateVariables_CTE(t *testing.T) { + tests := []struct { + name string + input string + expect string + }{ + { + name: "CTE with template params", + input: `WITH params AS ( + SELECT DATE_SUB(CURRENT_DATE(), INTERVAL $EndDayInterval DAY) AS end_date, + CAST(GREATEST($Page, 1) AS INT64) AS page_number +), +perf AS ( + SELECT p.agency_id, SUM(p.impressions) AS imps + FROM fact_performance p + JOIN params prm ON TRUE + WHERE p.event_date BETWEEN prm.start_date AND prm.end_date + ${predicate.Builder().CombineOr($predicate.FilterGroup(0, "AND")).Build("AND")} + GROUP BY 1 +) +SELECT v.* FROM perf v ORDER BY v.agency_id`, + expect: `WITH params AS ( + SELECT DATE_SUB(CURRENT_DATE(), INTERVAL '' DAY) AS end_date, + CAST(GREATEST('', 1) AS INT64) AS page_number +), +perf AS ( + SELECT p.agency_id, SUM(p.impressions) AS imps + FROM fact_performance p + JOIN params prm ON TRUE + WHERE p.event_date BETWEEN prm.start_date AND prm.end_date + '' + GROUP BY 1 +) +SELECT v.* FROM perf v ORDER BY v.agency_id`, + }, + { + name: "CTE with backtick tables (BigQuery)", + input: "WITH data AS (SELECT * FROM `project.dataset.table` t WHERE t.ID = $id) SELECT * FROM data", + expect: "WITH data AS (SELECT * FROM `project.dataset.table` t WHERE t.ID = '') SELECT * FROM data", + }, + { + name: "UNION ALL in CTE", + input: "WITH combined AS (SELECT * FROM T1 WHERE ID = $a UNION ALL SELECT * FROM T2 WHERE ID = $b) SELECT * FROM combined", + expect: "WITH combined AS (SELECT * FROM T1 WHERE ID = '' UNION ALL SELECT * FROM T2 WHERE ID = '') SELECT * FROM combined", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripTemplateVariables(tt.input) + assert.Equal(t, tt.expect, got) + }) + } +} + +// TestDiscoverySQL_Strategy documents the column discovery strategy for different SQL patterns. +// BigQuery CTEs cannot use WHERE 1=0 (full cost incurred), so the strategy should be: +// - Simple SELECT * FROM table → use table metadata (INFORMATION_SCHEMA or SELECT * WHERE 1=0) +// - SELECT with explicit columns → column names from AST, types from table metadata +// - CTE/WITH queries → parse final SELECT, resolve CTE chain to source tables, use metadata +// - SQL with velocity templates → strip templates, then apply above rules +func TestDiscoverySQL_Strategy(t *testing.T) { + tests := []struct { + name string + table string + sql string + expected string + desc string + assertions func(t *testing.T, result string) + }{ + { + name: "wildcard with templates — uses table fallback", + table: "VENDOR", + sql: "SELECT * FROM VENDOR t WHERE t.ID = $vendorID", + expected: "VENDOR", + desc: "template variables → table fallback (safe for all backends)", + }, + { + name: "explicit projection with templates — preserves projection", + table: "VENDOR", + sql: "SELECT ID FROM VENDOR t WHERE t.ID = $VendorID", + desc: "explicit select list should not widen to table columns", + assertions: func(t *testing.T, result string) { + assert.NotEqual(t, "VENDOR", result) + assert.Contains(t, strings.ToUpper(result), "SELECT ID") + }, + }, + { + name: "wildcard with EXCEPT — uses table fallback", + table: "VENDOR", + sql: "SELECT vendor.* EXCEPT VENDOR_ID FROM VENDOR vendor", + expected: "VENDOR", + desc: "EXCEPT clause → table fallback (EXCEPT is datly extension)", + }, + { + name: "clean explicit SQL gets falsified", + table: "VENDOR", + sql: "SELECT ID, NAME FROM VENDOR WHERE 1=1", + desc: "clean SQL → falsified with 1=0 injected", + assertions: func(t *testing.T, result string) { + assert.Contains(t, result, "1 = 0") + assert.Contains(t, strings.ToUpper(result), "ID") + assert.Contains(t, strings.ToUpper(result), "NAME") + }, + }, + { + name: "empty SQL uses table", + table: "VENDOR", + sql: "", + expected: "VENDOR", + desc: "no SQL → use table name", + }, + { + name: "CTE with templates — uses table fallback", + table: "VENDOR", + sql: "WITH cte AS (SELECT * FROM VENDOR WHERE ID = $id) SELECT * FROM cte", + expected: "VENDOR", + desc: "CTE with templates → table fallback (safe for BigQuery)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := &view.View{ + Name: "test", + Table: tt.table, + } + if tt.sql != "" { + v.Template = &view.Template{Source: tt.sql} + } + got := discoverySQL(v, nil) + if tt.assertions != nil { + tt.assertions(t, got) + } else { + assert.Equal(t, tt.expected, got, tt.desc) + } + }) + } +} + +func TestAllPlaceholderColumns(t *testing.T) { + tests := []struct { + name string + names []string + expect bool + }{ + {"empty", nil, false}, + {"real columns", []string{"ID", "NAME"}, false}, + {"all placeholders", []string{"col_1", "col_2"}, true}, + {"mixed", []string{"col_1", "NAME"}, false}, + {"single placeholder", []string{"col_1"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cols view.Columns + for _, n := range tt.names { + cols = append(cols, &view.Column{Name: n}) + } + assert.Equal(t, tt.expect, allPlaceholderColumns(cols)) + }) + } +} + +func TestNeedsDiscovery(t *testing.T) { + tests := []struct { + name string + view *view.View + expect bool + }{ + {"nil view", nil, false}, + {"no columns", &view.View{Name: "t"}, true}, + {"placeholder columns", &view.View{Name: "t", Columns: view.Columns{&view.Column{Name: "col_1"}}}, true}, + {"real columns no wildcard", &view.View{Name: "t", Columns: view.Columns{&view.Column{Name: "ID"}}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expect, needsDiscovery(tt.view)) + }) + } +} + +type placeholderSchemaRow struct { + Col1 string `sqlx:"name=col_1"` + Col2 string `sqlx:"name=col_2"` +} + +func TestColumnsFromSchema_IgnoresPlaceholderTypes(t *testing.T) { + aView := &view.View{ + Schema: state.NewSchema(reflect.TypeOf([]*placeholderSchemaRow{}), state.WithMany()), + } + assert.Nil(t, columnsFromSchema(aView)) +} diff --git a/repository/shape/compile/COMPONENT_CONTRACT_PARITY.md b/repository/shape/compile/COMPONENT_CONTRACT_PARITY.md new file mode 100644 index 000000000..d198dfe73 --- /dev/null +++ b/repository/shape/compile/COMPONENT_CONTRACT_PARITY.md @@ -0,0 +1,99 @@ +# Component Contract Parity Target (Shape) + +This document defines the target behavior for cross-component contract discovery in `repository/shape`. + +Scope: + +- Applies to DQL compile flow (`compile -> plan -> load`) +- Does not depend on `internal/*` packages +- Defines observable behavior and acceptance criteria + +## Problem Statement + +When DQL declares component dependencies (for example via component-typed state declarations), shape compile should produce component-facing IR that is functionally equivalent to translator contract/signature resolution for: + +- route reference normalization +- output schema/type resolution for component states +- dependent type propagation +- deterministic diagnostics + +Today, shape produces useful plan artifacts for views/states but component contract resolution parity is incomplete. + +## Target Semantics + +### 1) Reference Forms + +A component reference MUST support these forms: + +- Relative: `../acl/auth` +- Method-qualified absolute route: `GET:/v1/api/platform/acl/auth` +- Absolute route without method: `/v1/api/platform/acl/auth` (defaults to `GET`) + +Normalization target: + +- Stable route identity represented as `method + uri` (default method `GET`) +- Namespace/path derivation remains deterministic for file-layout lookups + +### 2) State Enrichment + +For each `plan.State` with `Kind == "component"`: + +- `In` retains user-declared logical reference (for traceability) +- `DataType` is inferred from referenced component output when not explicitly declared +- `OutputDataType` is preserved if user declares explicit output type + +If inferred type is unavailable, a diagnostic is emitted (see diagnostics section). + +### 3) Type Propagation + +Referenced component route/resource types required for consuming component states SHOULD be appended to `plan.Result.Types` unless a collision exists. + +Collision policy: + +- Existing local type names win +- Emit collision diagnostic for skipped imported type + +### 4) Nested Component Dependencies + +If referenced route/resource parameters include additional component references: + +- Resolver walks nested dependencies transitively +- Cycle detection MUST prevent infinite recursion +- Cycle reports a deterministic warning diagnostic + +### 5) Loader Classification + +Shape-loaded component artifact SHOULD classify component dependencies as input-like contract dependencies (not miscellaneous "other"). + +## Diagnostics Contract + +Component-related diagnostics use `DQL-COMP-*` codes: + +- `DQL-COMP-REF-INVALID` +- `DQL-COMP-ROUTE-MISSING` +- `DQL-COMP-ROUTE-INVALID` +- `DQL-COMP-CYCLE` +- `DQL-COMP-TYPE-COLLISION` + +Requirements: + +- Deterministic code/message per failure class +- Span points at the referenced component token when available +- Warnings by default unless compile strict mode escalates + +## Non-Goals (Step 1) + +- No resolver implementation changes +- No signature engine wiring changes +- No compile pipeline behavior changes + +This step defines the contract only; implementation phases follow separately. + +## Acceptance Criteria + +The parity contract is considered defined when: + +1. Reference normalization rules are explicit and unambiguous. +2. Required state/type/diagnostic behavior is documented for success and failure paths. +3. Nested dependency and collision behavior is documented. +4. Constraints are independent of `internal/*` packages. diff --git a/repository/shape/compile/compiler.go b/repository/shape/compile/compiler.go index 6fb6eadc0..17118fe0c 100644 --- a/repository/shape/compile/compiler.go +++ b/repository/shape/compile/compiler.go @@ -58,7 +58,7 @@ func (c *DQLCompiler) Compile(ctx context.Context, source *shape.Source, opts .. root, compileDiags, err := c.compileRoot( source.Name, prepared.Pre.SQL, prepared.Statements, prepared.Decision, - compileOptions.MixedMode, compileOptions.UnknownNonReadMode, + compileOptions.MixedMode, compileOptions.UnknownNonReadMode, prepared.Pre.Directives, ) if err != nil { return nil, err @@ -133,18 +133,45 @@ func (c *DQLCompiler) assembleResult( result.Diagnostics = diags result.TypeContext = prepared.Pre.TypeCtx result.Directives = prepared.Pre.Directives - applyDefaultConnectorDirective(result) + applyConstDirective(result) hints := extractViewHints(source.DQL) - appendRelationViews(result, root, hints) + relationSQLSource := prepared.Pre.SQL + if strings.TrimSpace(relationSQLSource) == "" { + relationSQLSource = source.DQL + } + appendRelationViews(result, root, hints, relationSQLSource) appendDeclaredViews(source.DQL, result) + applyDefaultConnectorDirective(result) appendDeclaredStates(source.DQL, result) applyViewHints(result, hints) + result.Diagnostics = append(result.Diagnostics, appendComponentTypesWithLayout(source, result, pathLayout)...) + for _, item := range result.Views { + if item == nil || strings.TrimSpace(item.SQL) == "" { + continue + } + item.SQL = stripProjectionHintCalls(item.SQL) + } + applyInlineParamHints(source.DQL, result) applySourceParityEnrichmentWithLayout(result, source, pathLayout) - applyLinkedTypeSupport(result, source) + ensureDQLComponentRouteWithLayout(result, source, pathLayout) + applySummaryTypeSupport(result, source) + if compileOptions.UseLinkedTypes == nil || *compileOptions.UseLinkedTypes { + applyLinkedTypeSupport(result, source) + } result.Diagnostics = append(result.Diagnostics, applyColumnDiscoveryPolicy(result, compileOptions)...) return result } +func applyConstDirective(result *plan.Result) { + if result == nil || result.Directives == nil || len(result.Directives.Const) == 0 { + return + } + result.Const = make(map[string]string, len(result.Directives.Const)) + for k, v := range result.Directives.Const { + result.Const[k] = v + } +} + func applyDefaultConnectorDirective(result *plan.Result) { if result == nil || result.Directives == nil { return @@ -161,9 +188,14 @@ func applyDefaultConnectorDirective(result *plan.Result) { } } -func (c *DQLCompiler) compileRoot(sourceName, sqlText string, statements dqlstmt.Statements, decision pipeline.Decision, mode shape.CompileMixedMode, unknownMode shape.CompileUnknownNonReadMode) (*plan.View, []*dqlshape.Diagnostic, error) { +func (c *DQLCompiler) compileRoot(sourceName, sqlText string, statements dqlstmt.Statements, decision pipeline.Decision, mode shape.CompileMixedMode, unknownMode shape.CompileUnknownNonReadMode, directives *dqlshape.Directives) (*plan.View, []*dqlshape.Diagnostic, error) { mode = normalizeMixedMode(mode) unknownMode = normalizeUnknownNonReadMode(unknownMode) + consts := map[string]string(nil) + groupableAliases := explicitGroupableAliases(extractViewHints(sqlText)) + if directives != nil && len(directives.Const) > 0 { + consts = directives.Const + } if !decision.HasRead && !decision.HasExec && decision.HasUnknown { diag := &dqlshape.Diagnostic{ Code: dqldiag.CodeParseUnknownNonRead, @@ -199,7 +231,7 @@ func (c *DQLCompiler) compileRoot(sourceName, sqlText string, statements dqlstmt break } } - view, diags, err := pipeline.BuildRead(sourceName, readSQL) + view, diags, err := pipeline.BuildReadWithOptions(sourceName, readSQL, consts, groupableAliases) diags = append(diags, &dqlshape.Diagnostic{ Code: dqldiag.CodeDMLMixed, Severity: dqlshape.SeverityWarning, @@ -223,7 +255,28 @@ func (c *DQLCompiler) compileRoot(sourceName, sqlText string, statements dqlstmt } return view, diags, nil } - return pipeline.BuildRead(sourceName, sqlText) + return pipeline.BuildReadWithOptions(sourceName, sqlText, consts, groupableAliases) +} + +func explicitGroupableAliases(hints map[string]viewHint) map[string]bool { + if len(hints) == 0 { + return nil + } + result := map[string]bool{} + for alias, hint := range hints { + if hint.Groupable == nil || !*hint.Groupable { + continue + } + alias = strings.ToLower(strings.TrimSpace(alias)) + if alias == "" { + continue + } + result[alias] = true + } + if len(result) == 0 { + return nil + } + return result } func normalizeMixedMode(mode shape.CompileMixedMode) shape.CompileMixedMode { diff --git a/repository/shape/compile/compiler_test.go b/repository/shape/compile/compiler_test.go index 85b51108c..b84cd75e3 100644 --- a/repository/shape/compile/compiler_test.go +++ b/repository/shape/compile/compiler_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -137,6 +138,26 @@ SELECT id FROM ORDERS o assert.Equal(t, "analytics", planned.Views[0].Connector) } +func TestDQLCompiler_Compile_AppliesDefaultConnectorToDeclaredViews(t *testing.T) { + compiler := New() + dqlPath := filepath.Join("..", "..", "..", "e2e", "v1", "dql", "dev", "team", "user_team.dql") + dql, err := os.ReadFile(dqlPath) + require.NoError(t, err) + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "user_team", Path: dqlPath, DQL: string(dql)}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.GreaterOrEqual(t, len(planned.Views), 2) + connectors := map[string]string{} + for _, candidate := range planned.Views { + if candidate != nil { + connectors[candidate.Name] = candidate.Connector + } + } + assert.Equal(t, "dev", connectors["user_team"]) + assert.Equal(t, "dev", connectors["TeamStats"]) +} + func TestDQLCompiler_Compile_ColumnDiscoveryAutoForWildcard(t *testing.T) { compiler := New() res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT * FROM ORDERS o"}) @@ -229,6 +250,54 @@ func TestDQLCompiler_Compile_SyntaxError_RemapsAfterSanitize(t *testing.T) { } } +func TestDQLCompiler_Compile_RelationSQLUsesSanitizedVeltyOutput(t *testing.T) { + compiler := New() + dql := ` +#setting($_ = $route('/v1/api/shape/dev/vendors/{vendorID}', 'GET')) +#define($_ = $VendorID(path/vendorID)) +SELECT wrapper.* EXCEPT ID, + vendor.*, + products.* EXCEPT VENDOR_ID, + setting.* EXCEPT ID +FROM (SELECT ID FROM VENDOR WHERE ID = $VendorID) wrapper +JOIN (SELECT * FROM VENDOR t WHERE t.ID = $VendorID) vendor ON vendor.ID = wrapper.ID +JOIN (SELECT * FROM (SELECT (1) AS IS_ACTIVE, (3) AS CHANNEL, CAST($VendorID AS SIGNED) AS ID) t) setting ON setting.ID = wrapper.ID +JOIN (SELECT * FROM PRODUCT t) products ON products.VENDOR_ID = vendor.ID` + + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "vendor_details", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.Contains(t, planned.ViewsByName, "vendor") + require.Contains(t, planned.ViewsByName, "setting") + require.Contains(t, planned.ViewsByName, "products") + assert.Contains(t, planned.ViewsByName["vendor"].SQL, "$criteria.AppendBinding($Unsafe.VendorID)") + assert.Contains(t, planned.ViewsByName["setting"].SQL, "CAST($criteria.AppendBinding($Unsafe.VendorID) AS SIGNED)") + require.NotNil(t, planned.ViewsByName["setting"].Declaration) + require.Contains(t, planned.ViewsByName["setting"].Declaration.ColumnsConfig, "ID") + assert.Equal(t, `internal:"true"`, planned.ViewsByName["setting"].Declaration.ColumnsConfig["ID"].Tag) + require.NotNil(t, planned.ViewsByName["products"].Declaration) + require.Contains(t, planned.ViewsByName["products"].Declaration.ColumnsConfig, "VENDOR_ID") + assert.Equal(t, `internal:"true"`, planned.ViewsByName["products"].Declaration.ColumnsConfig["VENDOR_ID"].Tag) +} + +func TestDQLCompiler_Compile_PopulatesComponentRouteFromDirective(t *testing.T) { + compiler := New() + dql := ` +#setting($_ = $route('/v1/api/shape/dev/vendors/{vendorID}', 'DELETE')) +#define($_ = $VendorID(path/vendorID)) +SELECT ID FROM VENDOR WHERE ID = $VendorID` + + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "vendor_delete", DQL: dql}) + require.NoError(t, err) + planned, ok := plan.ResultFrom(res) + require.True(t, ok) + require.Len(t, planned.Components, 1) + assert.Equal(t, "DELETE", planned.Components[0].Method) + assert.Equal(t, "/v1/api/shape/dev/vendors/{vendorID}", planned.Components[0].RoutePath) + assert.Equal(t, planned.Views[0].Name, planned.Components[0].ViewName) +} + func TestDQLCompiler_Compile_DirectiveOnly_HasLineAndChar(t *testing.T) { compiler := New() _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "#package('x')"}) @@ -258,6 +327,22 @@ func TestDQLCompiler_Compile_InvalidDirective_HasLineAndChar(t *testing.T) { assert.Equal(t, 1, d.Span.Start.Char) } +func TestDQLCompiler_Compile_SQLSyntaxWithDirective_HasExactLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_report", + DQL: "#setting($_ = $route('/x', 'GET'))\nSELECT id FROM ORDERS WHERE (", + }) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + d := compileErr.Diagnostics[0] + assert.Equal(t, dqldiag.CodeParseSyntax, d.Code) + assert.Equal(t, 2, d.Span.Start.Line) + assert.Equal(t, 29, d.Span.Start.Char) +} + func TestDQLCompiler_Compile_ExtractsJoinLinks(t *testing.T) { compiler := New() dql := "SELECT o.id, i.sku FROM orders o JOIN order_items i ON o.id = i.order_id" @@ -354,8 +439,15 @@ SELECT id FROM ORDERS t` planned, ok := plan.ResultFrom(res) require.True(t, ok) require.Len(t, planned.Views, 2) - extra := planned.ViewsByName["e"] + var extra *plan.View + for _, item := range planned.Views { + if item != nil && strings.Contains(item.SQL, "SELECT code FROM EXTRA e") { + extra = item + break + } + } require.NotNil(t, extra) + assert.Equal(t, "Extra", extra.Name) assert.Equal(t, "EXTRA", extra.Table) assert.Contains(t, extra.SQL, "SELECT code FROM EXTRA e") } @@ -369,8 +461,15 @@ SELECT id FROM ORDERS t` require.NoError(t, err) planned, ok := plan.ResultFrom(res) require.True(t, ok) - extra := planned.ViewsByName["e"] + var extra *plan.View + for _, item := range planned.Views { + if item != nil && strings.Contains(item.SQL, "SELECT code FROM EXTRA e") { + extra = item + break + } + } require.NotNil(t, extra) + assert.Equal(t, "Extra", extra.Name) assert.Equal(t, "/v1/extra", extra.SQLURI) assert.Equal(t, "analytics", extra.Connector) assert.Equal(t, "one", extra.Cardinality) @@ -511,7 +610,7 @@ func TestDQLCompiler_Compile_MixedMode_ReadWins(t *testing.T) { require.NotEmpty(t, planned.Views) assert.Equal(t, "o", planned.Views[0].Name) assert.Equal(t, "ORDERS", planned.Views[0].Table) - assert.Contains(t, planned.Views[0].SQL, "SELECT o.id FROM ORDERS o") + assert.Contains(t, planned.Views[0].SQL, "SELECT * FROM ORDERS o") assert.NotContains(t, planned.Views[0].SQL, "UPDATE ORDERS") require.NotEmpty(t, planned.Diagnostics) assert.Equal(t, dqldiag.CodeDMLMixed, planned.Diagnostics[len(planned.Diagnostics)-1].Code) diff --git a/repository/shape/compile/component_route_shape.go b/repository/shape/compile/component_route_shape.go new file mode 100644 index 000000000..33ef9f6df --- /dev/null +++ b/repository/shape/compile/component_route_shape.go @@ -0,0 +1,59 @@ +package compile + +import ( + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func ensureDQLComponentRouteWithLayout(result *plan.Result, source *shape.Source, layout compilePathLayout) { + if result == nil || len(result.Components) > 0 { + return + } + root := firstPlannedView(result.Views) + if root == nil { + return + } + + settings := extractRuleSettings(source, result.Directives) + method := httpMethod(settings) + uri := strings.TrimSpace(settings.URI) + if uri == "" { + namespace := "" + if source != nil && strings.TrimSpace(source.Path) != "" { + namespace, _ = dqlToRouteNamespaceWithLayout(source.Path, layout) + } + if namespace != "" { + uri = inferDefaultURI(namespace) + } + if uri == "" && source != nil { + uri = normalizeURI(source.Name) + } + } + if uri == "" { + return + } + + name := root.Name + if source != nil && strings.TrimSpace(source.Name) != "" { + name = strings.TrimSpace(source.Name) + } + result.Components = []*plan.ComponentRoute{{ + Name: name, + ViewName: strings.TrimSpace(root.Name), + RoutePath: normalizeURI(uri), + Method: method, + Connector: strings.TrimSpace(root.Connector), + SourceURL: strings.TrimSpace(root.SQLURI), + }} +} + +func firstPlannedView(views []*plan.View) *plan.View { + for _, item := range views { + if item != nil { + return item + } + } + return nil +} diff --git a/repository/shape/compile/component_types.go b/repository/shape/compile/component_types.go index c7c8f22db..5ec5e3b28 100644 --- a/repository/shape/compile/component_types.go +++ b/repository/shape/compile/component_types.go @@ -1,15 +1,19 @@ package compile import ( + "context" "os" "path/filepath" + "reflect" "sort" "strings" "github.com/viant/datly/repository/shape" dqldiag "github.com/viant/datly/repository/shape/dql/diag" dqlshape "github.com/viant/datly/repository/shape/dql/shape" + shapeLoad "github.com/viant/datly/repository/shape/load" "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view" "github.com/viant/datly/view/state" "gopkg.in/yaml.v3" ) @@ -37,16 +41,15 @@ func appendComponentTypesWithLayout(source *shape.Source, result *plan.Result, l sourceNamespace, _ := dqlToRouteNamespaceWithLayout(source.Path, layout) collector := &componentCollector{ routesRoot: routesRoot, + dqlRoot: dqlRoot, + layout: layout, visited: map[string]componentVisitState{}, outputByRoute: map[string]string{}, + routeByNS: map[string]string{}, typesByName: map[string]*plan.Type{}, payloadCache: map[string]routePayloadLookup{}, reportedDiag: map[string]bool{}, } - if strings.TrimSpace(sourceNamespace) != "" { - collector.collect(sourceNamespace, relationSpan(source.DQL, 0), false) - } - for _, stateItem := range result.States { if stateItem == nil || state.Kind(strings.ToLower(stateItem.KindString())) != state.KindComponent { continue @@ -66,14 +69,38 @@ func appendComponentTypesWithLayout(source *shape.Source, result *plan.Result, l }) continue } + if routeKey, ok := collector.resolveRoute(ref, source.Path); ok && stateItem.In != nil { + stateItem.In.Name = routeKey + } outputType, ok := collector.collect(namespace, componentRefSpan(source.DQL, ref), true) + if routeKey := collector.routeKey(namespace); routeKey != "" && stateItem.In != nil { + stateItem.In.Name = routeKey + } if ok && strings.TrimSpace(outputType) != "" { if stateItem.Schema == nil { stateItem.Schema = &state.Schema{} } + if stateItem.Schema.Type() == nil { + if lookup, found := collector.payloadCache[strings.ToLower(strings.TrimSpace(namespace))]; found && lookup.outputType != nil { + stateItem.Schema.SetType(lookup.outputType) + } + } if strings.TrimSpace(stateItem.Schema.DataType) == "" { stateItem.Schema.DataType = strings.TrimSpace(outputType) } + if payload, found := collector.loadRoutePayload(namespace, componentRefSpan(source.DQL, ref)); found { + if pkg, modulePath := routeOutputPackage(payload, outputType); pkg != "" || modulePath != "" { + if strings.TrimSpace(stateItem.Schema.Package) == "" { + stateItem.Schema.Package = pkg + } + if strings.TrimSpace(stateItem.Schema.PackagePath) == "" { + stateItem.Schema.PackagePath = pkg + } + if strings.TrimSpace(stateItem.Schema.ModulePath) == "" { + stateItem.Schema.ModulePath = modulePath + } + } + } } } @@ -110,8 +137,13 @@ func appendComponentTypesWithLayout(source *shape.Source, result *plan.Result, l type componentCollector struct { routesRoot string + dqlRoot string + layout compilePathLayout + routeIndex *RouteIndex + routeIndexErr error visited map[string]componentVisitState outputByRoute map[string]string + routeByNS map[string]string // typesByName provides O(1) dedup; typeOrder tracks insertion sequence // so the final list can be sorted once rather than extracted from the map. typesByName map[string]*plan.Type @@ -123,6 +155,7 @@ type componentCollector struct { type routePayloadLookup struct { payload *routePayload + outputType reflect.Type found bool malformed bool malformedAt string @@ -187,6 +220,9 @@ func (c *componentCollector) collect(namespace string, span dqlshape.Span, requi outputType := routeOutputType(payload) c.outputByRoute[key] = outputType + if routeKey := routePayloadKey(payload); routeKey != "" { + c.routeByNS[key] = routeKey + } for _, param := range payload.Resource.Parameters { if !strings.EqualFold(strings.TrimSpace(param.In.Kind), string(state.KindComponent)) { @@ -210,6 +246,55 @@ func (c *componentCollector) collect(namespace string, span dqlshape.Span, requi return outputType, true } +func (c *componentCollector) routeKey(namespace string) string { + if c == nil { + return "" + } + return strings.TrimSpace(c.routeByNS[strings.ToLower(strings.TrimSpace(namespace))]) +} + +func (c *componentCollector) resolveRoute(ref, currentSource string) (string, bool) { + index, err := c.lazyRouteIndex() + if err != nil || index == nil { + return "", false + } + opts := c.layoutCompileOptions() + return index.Resolve(ref, currentSource, opts...) +} + +func (c *componentCollector) lazyRouteIndex() (*RouteIndex, error) { + if c == nil { + return nil, nil + } + if c.routeIndex != nil || c.routeIndexErr != nil { + return c.routeIndex, c.routeIndexErr + } + if strings.TrimSpace(c.dqlRoot) == "" { + return nil, nil + } + paths, err := collectDQLSources(c.dqlRoot) + if err != nil { + c.routeIndexErr = err + return nil, err + } + c.routeIndex, c.routeIndexErr = BuildRouteIndex(paths, c.layoutCompileOptions()...) + return c.routeIndex, c.routeIndexErr +} + +func (c *componentCollector) layoutCompileOptions() []shape.CompileOption { + if c == nil { + return nil + } + var opts []shape.CompileOption + if marker := strings.TrimSpace(c.layout.dqlMarker); marker != "" { + opts = append(opts, shape.WithDQLPathMarker(marker)) + } + if rel := strings.TrimSpace(c.layout.routesRelative); rel != "" { + opts = append(opts, shape.WithRoutesRelativePath(rel)) + } + return opts +} + func sourceRootsWithLayout(sourcePath string, layout compilePathLayout) (platformRoot, routesRoot, dqlRoot string, ok bool) { path := filepath.Clean(strings.TrimSpace(sourcePath)) if path == "" { @@ -348,6 +433,8 @@ type routePayload struct { } `yaml:"Parameters"` } `yaml:"Resource"` Routes []struct { + Method string `yaml:"Method"` + URI string `yaml:"URI"` Handler struct { OutputType string `yaml:"OutputType"` } `yaml:"Handler"` @@ -388,6 +475,82 @@ func readRoutePayload(routesRoot, namespace string) routePayloadLookup { return lookup } +func readDQLPayload(dqlRoot string, layout compilePathLayout, namespace string, sourcePath string) routePayloadLookup { + candidates := []string{} + if strings.TrimSpace(sourcePath) != "" { + candidates = append(candidates, sourcePath) + } + candidates = append(candidates, dqlSourceCandidates(dqlRoot, namespace)...) + lookup := routePayloadLookup{} + for _, candidate := range candidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + source := &shape.Source{ + Name: strings.TrimSuffix(filepath.Base(candidate), filepath.Ext(candidate)), + Path: candidate, + DQL: string(data), + } + opts := []shape.CompileOption{} + if marker := strings.TrimSpace(layout.dqlMarker); marker != "" { + opts = append(opts, shape.WithDQLPathMarker(marker)) + } + if rel := strings.TrimSpace(layout.routesRelative); rel != "" { + opts = append(opts, shape.WithRoutesRelativePath(rel)) + } + planned, err := New().Compile(context.Background(), source, opts...) + if err != nil { + if !lookup.malformed { + lookup.malformed = true + lookup.malformedAt = candidate + lookup.detail = strings.TrimSpace(err.Error()) + } + continue + } + result, ok := plan.ResultFrom(planned) + if !ok || result == nil { + if !lookup.malformed { + lookup.malformed = true + lookup.malformedAt = candidate + lookup.detail = "unexpected compiled plan result" + } + continue + } + artifact, err := shapeLoad.New().LoadComponent(context.Background(), planned) + if err != nil { + if !lookup.malformed { + lookup.malformed = true + lookup.malformedAt = candidate + lookup.detail = strings.TrimSpace(err.Error()) + } + continue + } + component, _ := shapeLoad.ComponentFrom(artifact) + lookup.payload = routePayloadFromPlan(result, component) + lookup.outputType = routeOutputReflectType(component, artifact.Resource) + applyDQLRoutePayload(lookup.payload, source, result, namespace) + lookup.found = true + lookup.malformed = false + lookup.malformedAt = "" + lookup.detail = "" + return lookup + } + return lookup +} + +func (c *componentCollector) componentSourcePath(namespace string) string { + index, err := c.lazyRouteIndex() + if err != nil || index == nil { + return "" + } + entries := index.ByNamespace[strings.ToLower(strings.TrimSpace(namespace))] + if len(entries) != 1 || entries[0] == nil { + return "" + } + return strings.TrimSpace(entries[0].SourcePath) +} + func (c *componentCollector) loadRoutePayload(namespace string, span dqlshape.Span) (*routePayload, bool) { key := strings.ToLower(strings.TrimSpace(namespace)) if key == "" { @@ -396,6 +559,12 @@ func (c *componentCollector) loadRoutePayload(namespace string, span dqlshape.Sp lookup, ok := c.payloadCache[key] if !ok { lookup = readRoutePayload(c.routesRoot, namespace) + if !lookup.found && strings.TrimSpace(c.dqlRoot) != "" { + dqlLookup := readDQLPayload(c.dqlRoot, c.layout, namespace, c.componentSourcePath(namespace)) + if dqlLookup.found || dqlLookup.malformed { + lookup = dqlLookup + } + } c.payloadCache[key] = lookup } if lookup.malformed && !lookup.found && !c.hasReported("invalid:"+key) { @@ -472,6 +641,54 @@ func routeOutputType(payload *routePayload) string { return "" } +func routeOutputPackage(payload *routePayload, outputType string) (string, string) { + if payload == nil { + return "", "" + } + if len(payload.Routes) > 0 { + if pkg := strings.TrimSpace(payload.Routes[0].Output.Type.Package); pkg != "" { + modulePath := routeTypeModulePath(payload, strings.TrimSpace(payload.Routes[0].Output.Type.Name)) + return pkg, modulePath + } + } + leaf := strings.Trim(strings.TrimSpace(outputType), "*") + if leaf == "" { + return "", "" + } + for _, item := range payload.Resource.Types { + name := strings.TrimSpace(item.Name) + dataType := strings.Trim(strings.TrimSpace(item.DataType), "*") + if strings.EqualFold(name, leaf) || strings.EqualFold(dataType, leaf) { + return strings.TrimSpace(item.Package), strings.TrimSpace(item.ModulePath) + } + } + for _, param := range payload.Resource.Parameters { + if !strings.EqualFold(strings.TrimSpace(param.In.Kind), string(state.KindOutput)) { + continue + } + if name := strings.Trim(strings.TrimSpace(param.Schema.Name), "*"); name == leaf { + return strings.TrimSpace(param.Schema.Package), "" + } + if dataType := strings.Trim(strings.TrimSpace(param.Schema.DataType), "*"); dataType == leaf { + return strings.TrimSpace(param.Schema.Package), "" + } + } + return "", "" +} + +func routeTypeModulePath(payload *routePayload, name string) string { + name = strings.Trim(strings.TrimSpace(name), "*") + if payload == nil || name == "" { + return "" + } + for _, item := range payload.Resource.Types { + if strings.EqualFold(strings.TrimSpace(item.Name), name) { + return strings.TrimSpace(item.ModulePath) + } + } + return "" +} + func componentRefSpan(raw, ref string) dqlshape.Span { offset := 0 ref = strings.TrimSpace(ref) @@ -494,3 +711,310 @@ func routeYAMLCandidates(routesRoot, namespace string) []string { filepath.Join(routesRoot, filepath.FromSlash(namespace), leaf+".yaml"), } } + +func routePayloadKey(payload *routePayload) string { + if payload == nil { + return "" + } + for _, route := range payload.Routes { + if uri := strings.TrimSpace(route.URI); uri != "" { + return normalizeRouteKey(strings.TrimSpace(route.Method), uri) + } + } + return "" +} + +func collectDQLSources(root string) ([]string, error) { + var result []string + err := filepath.WalkDir(root, func(candidate string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + if !isComponentDQLSourceFile(candidate) { + return nil + } + result = append(result, candidate) + return nil + }) + if err != nil { + return nil, err + } + sort.Strings(result) + return result, nil +} + +func isComponentDQLSourceFile(path string) bool { + ext := strings.ToLower(strings.TrimSpace(filepath.Ext(path))) + return ext == ".dql" || ext == ".sql" +} + +func dqlSourceCandidates(dqlRoot, namespace string) []string { + namespace = strings.Trim(namespace, "/") + if namespace == "" || strings.TrimSpace(dqlRoot) == "" { + return nil + } + leaf := filepath.Base(namespace) + base := filepath.Join(dqlRoot, filepath.FromSlash(namespace)) + return []string{ + base + ".dql", + base + ".sql", + filepath.Join(base, leaf+".dql"), + filepath.Join(base, leaf+".sql"), + } +} + +func routePayloadFromPlan(result *plan.Result, component *shapeLoad.Component) *routePayload { + if result == nil { + return nil + } + payload := &routePayload{} + for _, item := range result.Types { + if item == nil { + continue + } + payload.Resource.Types = append(payload.Resource.Types, struct { + Name string `yaml:"Name"` + Alias string `yaml:"Alias"` + DataType string `yaml:"DataType"` + Cardinality string `yaml:"Cardinality"` + Package string `yaml:"Package"` + ModulePath string `yaml:"ModulePath"` + }{ + Name: strings.TrimSpace(item.Name), + Alias: strings.TrimSpace(item.Alias), + DataType: strings.TrimSpace(item.DataType), + Cardinality: strings.TrimSpace(item.Cardinality), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + }) + } + ensureComponentOutputType(payload, component) + for _, item := range result.States { + if item == nil { + continue + } + param := struct { + Name string `yaml:"Name"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + Schema struct { + DataType string `yaml:"DataType"` + Name string `yaml:"Name"` + Package string `yaml:"Package"` + Cardinality string `yaml:"Cardinality"` + } `yaml:"Schema"` + }{Name: strings.TrimSpace(item.Name)} + if item.In != nil { + param.In.Kind = string(item.In.Kind) + param.In.Name = strings.TrimSpace(item.In.Name) + } + if item.Schema != nil { + param.Schema.DataType = strings.TrimSpace(item.Schema.DataType) + param.Schema.Name = strings.TrimSpace(item.Schema.Name) + param.Schema.Package = strings.TrimSpace(item.Schema.Package) + param.Schema.Cardinality = string(item.Schema.Cardinality) + } + payload.Resource.Parameters = append(payload.Resource.Parameters, param) + } + if outputType := componentOutputType(component, result); outputType != "" { + payload.Routes = append(payload.Routes, struct { + Method string `yaml:"Method"` + URI string `yaml:"URI"` + Handler struct { + OutputType string `yaml:"OutputType"` + } `yaml:"Handler"` + Output struct { + Cardinality string `yaml:"Cardinality"` + Type struct { + Name string `yaml:"Name"` + Package string `yaml:"Package"` + } `yaml:"Type"` + } `yaml:"Output"` + }{}) + payload.Routes[0].Handler.OutputType = outputType + if name, pkg := componentOutputName(component); name != "" { + payload.Routes[0].Output.Type.Name = name + payload.Routes[0].Output.Type.Package = pkg + } + } + return payload +} + +func componentOutputType(component *shapeLoad.Component, result *plan.Result) string { + if name, _ := componentOutputName(component); name != "" { + return "*" + strings.Trim(name, "*") + } + if component != nil { + for _, item := range component.Output { + if item == nil { + continue + } + if outputType := strings.TrimSpace(item.OutputDataType); outputType != "" { + return outputType + } + if item.Schema != nil { + if dataType := strings.TrimSpace(item.Schema.DataType); dataType != "" { + return dataType + } + if name := strings.TrimSpace(item.Schema.Name); name != "" { + return "*" + strings.Trim(name, "*") + } + } + } + } + return planOutputType(result) +} + +func routeOutputReflectType(component *shapeLoad.Component, resource *view.Resource) reflect.Type { + if component == nil || resource == nil { + return nil + } + pkgPath := "" + if component.TypeContext != nil { + pkgPath = strings.TrimSpace(component.TypeContext.PackagePath) + if pkgPath == "" { + pkgPath = strings.TrimSpace(component.TypeContext.DefaultPackage) + } + } + params := resource.Parameters.FilterByKind(state.KindOutput) + if len(params) == 0 { + params = component.OutputParameters() + } + if len(params) == 0 { + return nil + } + rt, err := params.ReflectType(pkgPath, resource.LookupType()) + if err != nil { + return nil + } + return rt +} + +func ensureComponentOutputType(payload *routePayload, component *shapeLoad.Component) { + if payload == nil || component == nil { + return + } + name, pkg := componentOutputName(component) + modulePath := "" + if component.TypeContext != nil { + modulePath = strings.TrimSpace(component.TypeContext.PackagePath) + } + if name == "" || pkg == "" || modulePath == "" { + return + } + for _, item := range payload.Resource.Types { + if strings.EqualFold(strings.TrimSpace(item.Name), name) { + return + } + } + payload.Resource.Types = append(payload.Resource.Types, struct { + Name string `yaml:"Name"` + Alias string `yaml:"Alias"` + DataType string `yaml:"DataType"` + Cardinality string `yaml:"Cardinality"` + Package string `yaml:"Package"` + ModulePath string `yaml:"ModulePath"` + }{ + Name: name, + DataType: "*" + name, + Package: pkg, + ModulePath: modulePath, + }) +} + +func componentOutputName(component *shapeLoad.Component) (string, string) { + if component == nil { + return "", "" + } + name := generatedComponentTypeBase(component) + "Output" + if spec := component.TypeSpecs["output"]; spec != nil && strings.TrimSpace(spec.TypeName) != "" { + name = strings.TrimSpace(spec.TypeName) + } + pkg := "" + if component.TypeContext != nil { + pkg = strings.TrimSpace(component.TypeContext.PackagePath) + if pkg == "" { + pkg = strings.TrimSpace(component.TypeContext.DefaultPackage) + } + } + return name, pkg +} + +func generatedComponentTypeBase(component *shapeLoad.Component) string { + if component == nil { + return "Component" + } + name := strings.TrimSpace(component.RootView) + if name == "" { + name = strings.TrimSpace(component.Name) + } + if name == "" { + name = "Component" + } + return state.SanitizeTypeName(name) +} + +func planOutputType(result *plan.Result) string { + if result == nil { + return "" + } + for _, item := range result.States { + if item == nil || !strings.EqualFold(item.KindString(), string(state.KindOutput)) { + continue + } + if outputType := strings.TrimSpace(item.OutputDataType); outputType != "" { + return outputType + } + if item.Schema != nil { + if dataType := strings.TrimSpace(item.Schema.DataType); dataType != "" { + return dataType + } + if name := strings.TrimSpace(item.Schema.Name); name != "" { + return "*" + strings.Trim(name, "*") + } + } + } + for _, item := range result.Types { + if item == nil || !strings.EqualFold(strings.TrimSpace(item.Name), "Output") { + continue + } + if dataType := strings.TrimSpace(item.DataType); dataType != "" { + return dataType + } + return "*Output" + } + return "" +} + +func applyDQLRoutePayload(payload *routePayload, source *shape.Source, result *plan.Result, namespace string) { + if payload == nil || len(payload.Routes) == 0 { + return + } + settings := extractRuleSettings(source, nil) + if result != nil { + settings = extractRuleSettings(source, result.Directives) + } + method := httpMethod(settings) + uri := strings.TrimSpace(settings.URI) + if uri == "" { + uri = inferDefaultURI(namespace) + } + payload.Routes[0].Method = method + payload.Routes[0].URI = normalizeURI(uri) +} + +func httpMethod(settings *ruleSettings) string { + if settings == nil { + return "GET" + } + methods := parseRouteMethods(settings.Method) + if len(methods) == 0 { + return "GET" + } + return strings.ToUpper(strings.TrimSpace(methods[0])) +} diff --git a/repository/shape/compile/component_types_test.go b/repository/shape/compile/component_types_test.go index 2cf803b7f..090f50366 100644 --- a/repository/shape/compile/component_types_test.go +++ b/repository/shape/compile/component_types_test.go @@ -203,3 +203,32 @@ func TestAppendComponentTypes_InvalidRouteYAMLDedupedForRepeatedStates(t *testin } assert.Equal(t, 1, invalidCount) } + +func TestAppendComponentTypes_FallsBackToSiblingDQLComponent(t *testing.T) { + temp := t.TempDir() + sourceDir := filepath.Join(temp, "dql", "dev", "vendor") + refDir := filepath.Join(temp, "dql", "dev") + require.NoError(t, os.MkdirAll(sourceDir, 0o755)) + require.NoError(t, os.MkdirAll(refDir, 0o755)) + + sourcePath := filepath.Join(sourceDir, "vendors.dql") + refPath := filepath.Join(refDir, "user_acl.dql") + sourceDQL := "#define($_ = $Auth(component/../user_acl))\nSELECT 1" + refDQL := "#package('github.com/viant/datly/e2e/v1/shape/dev/vendor/user_acl')\n#setting($_ = $route('/v1/api/dev/user-acl', 'GET'))\n#define($_ = $Auth(output/view).Embed())\nSELECT 1 AS UserID, TRUE AS IsReadOnly, TRUE AS Feature1" + require.NoError(t, os.WriteFile(sourcePath, []byte(sourceDQL), 0o644)) + require.NoError(t, os.WriteFile(refPath, []byte(refDQL), 0o644)) + + result := &plan.Result{ + States: []*plan.State{ + {Parameter: state.Parameter{Name: "Auth", In: &state.Location{Kind: state.KindComponent, Name: "../user_acl"}}}, + }, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: sourceDQL}, result) + for _, item := range diags { + require.NotEqual(t, dqldiag.CodeCompRouteMissing, item.Code) + } + require.NotNil(t, result.States[0].Schema) + assert.Equal(t, "*UserAclOutput", result.States[0].Schema.DataType) + assert.Equal(t, "github.com/viant/datly/e2e/v1/shape/dev/vendor/user_acl", result.States[0].Schema.Package) + assert.Equal(t, "GET:/v1/api/dev/user-acl", result.States[0].In.Name) +} diff --git a/repository/shape/compile/enrich.go b/repository/shape/compile/enrich.go index 3ea8a277b..c36ac8ef4 100644 --- a/repository/shape/compile/enrich.go +++ b/repository/shape/compile/enrich.go @@ -66,6 +66,9 @@ func buildParityEnrichmentContext(result *plan.Result, source *shape.Source, lay joinEmbedRefs: map[string]string{}, joinSubqueryBodies: map[string]string{}, } + if rootBaseDir := resultRootSQLBaseDir(result); rootBaseDir != "" { + ctx.baseDir = rootBaseDir + } if len(result.Views) == 0 || result.Views[0] == nil { return ctx } @@ -78,6 +81,17 @@ func buildParityEnrichmentContext(result *plan.Result, source *shape.Source, lay return ctx } +func resultRootSQLBaseDir(result *plan.Result) string { + if result == nil || len(result.Views) == 0 || result.Views[0] == nil { + return "" + } + rootName := strings.TrimSpace(result.Views[0].Name) + if rootName == "" { + return "" + } + return rootName +} + func applyViewDefaults(item *plan.View, root bool, ctx *parityEnrichmentContext) { if item == nil || ctx == nil { return diff --git a/repository/shape/compile/enrich_test.go b/repository/shape/compile/enrich_test.go index 1c532d398..b62522c37 100644 --- a/repository/shape/compile/enrich_test.go +++ b/repository/shape/compile/enrich_test.go @@ -45,6 +45,24 @@ func TestApplySourceParityEnrichment_InferTableFromSubquery(t *testing.T) { require.Equal(t, "advertiser/advertiser.sql", result.Views[0].SQLURI) } +func TestApplySourceParityEnrichment_UsesRootViewNameForSQLBaseDir(t *testing.T) { + source := &shape.Source{ + Path: "/repo/dql/dev/vendor/child_meta.dql", + DQL: `SELECT vendor.*, products.* FROM VENDOR vendor JOIN PRODUCT products ON products.VENDOR_ID = vendor.ID`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "vendor", Table: "VENDOR", SQL: "SELECT * FROM VENDOR"}, + {Name: "products", Table: "PRODUCT", SQL: "SELECT * FROM PRODUCT"}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "vendor/vendor.sql", result.Views[0].SQLURI) + require.Equal(t, "vendor/products.sql", result.Views[1].SQLURI) +} + func TestApplySourceParityEnrichment_InferTableFromEmbed(t *testing.T) { tempDir := t.TempDir() dqlDir := filepath.Join(tempDir, "dql", "platform", "timezone") diff --git a/repository/shape/compile/hints.go b/repository/shape/compile/hints.go index a16b855d2..117209e13 100644 --- a/repository/shape/compile/hints.go +++ b/repository/shape/compile/hints.go @@ -5,13 +5,23 @@ import ( "strconv" "strings" + "github.com/viant/datly/repository/shape/dql/decl" "github.com/viant/datly/repository/shape/plan" ) type viewHint struct { - Connector string - AllowNulls *bool - NoLimit *bool + Connector string + AllowNulls *bool + Groupable *bool + NoLimit *bool + CacheRef string + Limit *int + Cardinality string + Dest string + TypeName string + Self *plan.SelfReference + SelectorOrderBy *bool + SelectorOrderByNames map[string]string } func extractViewHints(dql string) map[string]viewHint { @@ -22,7 +32,7 @@ func extractViewHints(dql string) map[string]viewHint { if len(call.args) != 2 { continue } - alias := strings.TrimSpace(call.args[0]) + alias := normalizeHintAlias(call.args[0]) connector := unquote(strings.TrimSpace(call.args[1])) if !isIdentifier(alias) || !isIdentifier(connector) { continue @@ -34,7 +44,7 @@ func extractViewHints(dql string) map[string]viewHint { if len(call.args) != 1 { continue } - alias := strings.TrimSpace(call.args[0]) + alias := normalizeHintAlias(call.args[0]) if !isIdentifier(alias) { continue } @@ -42,11 +52,40 @@ func extractViewHints(dql string) map[string]viewHint { value := true hint.AllowNulls = &value result[alias] = hint + case "groupable", "grouping_enabled": + if len(call.args) != 1 { + continue + } + alias := normalizeHintAlias(call.args[0]) + if !isIdentifier(alias) { + continue + } + hint := result[alias] + value := true + hint.Groupable = &value + result[alias] = hint + case "allowed_order_by_columns": + if len(call.args) != 2 { + continue + } + alias := normalizeHintAlias(call.args[0]) + columns := strings.TrimSpace(unquote(strings.TrimSpace(call.args[1]))) + if !isIdentifier(alias) || columns == "" { + continue + } + hint := result[alias] + value := true + hint.SelectorOrderBy = &value + if hint.SelectorOrderByNames == nil { + hint.SelectorOrderByNames = map[string]string{} + } + appendAllowedOrderByColumns(hint.SelectorOrderByNames, columns) + result[alias] = hint case "set_limit": if len(call.args) != 2 { continue } - alias := strings.TrimSpace(call.args[0]) + alias := normalizeHintAlias(call.args[0]) limitRaw := strings.TrimSpace(call.args[1]) if !isIdentifier(alias) || limitRaw == "" { continue @@ -58,6 +97,74 @@ func extractViewHints(dql string) map[string]viewHint { hint := result[alias] noLimit := limit == 0 hint.NoLimit = &noLimit + if limit > 0 { + hint.Limit = &limit + } + result[alias] = hint + case "set_cache": + if len(call.args) != 2 { + continue + } + alias := normalizeHintAlias(call.args[0]) + ref := unquote(strings.TrimSpace(call.args[1])) + if !isIdentifier(alias) || ref == "" { + continue + } + hint := result[alias] + hint.CacheRef = ref + result[alias] = hint + case "cardinality": + if len(call.args) != 2 { + continue + } + alias := normalizeHintAlias(call.args[0]) + value := strings.ToLower(strings.TrimSpace(unquote(strings.TrimSpace(call.args[1])))) + if !isIdentifier(alias) { + continue + } + if value != "one" && value != "many" { + continue + } + hint := result[alias] + hint.Cardinality = value + result[alias] = hint + case "self_ref": + if len(call.args) != 4 { + continue + } + alias := normalizeHintAlias(call.args[0]) + holder := unquote(strings.TrimSpace(call.args[1])) + child := unquote(strings.TrimSpace(call.args[2])) + parent := unquote(strings.TrimSpace(call.args[3])) + if alias == "" || holder == "" || child == "" || parent == "" { + continue + } + hint := result[alias] + hint.Self = &plan.SelfReference{Holder: holder, Child: child, Parent: parent} + result[alias] = hint + case "dest": + if len(call.args) != 2 { + continue + } + alias := normalizeHintAlias(call.args[0]) + dest := strings.TrimSpace(unquote(strings.TrimSpace(call.args[1]))) + if !isIdentifier(alias) || dest == "" { + continue + } + hint := result[alias] + hint.Dest = dest + result[alias] = hint + case "type": + if len(call.args) != 2 { + continue + } + alias := normalizeHintAlias(call.args[0]) + typeName := strings.TrimSpace(unquote(strings.TrimSpace(call.args[1]))) + if !isIdentifier(alias) || typeName == "" { + continue + } + hint := result[alias] + hint.TypeName = typeName result[alias] = hint } } @@ -70,115 +177,30 @@ type hintCall struct { } func scanHintCalls(input string) []hintCall { - result := make([]hintCall, 0) - for i := 0; i < len(input); { - if !isIdentifierStart(input[i]) { - i++ - continue - } - start := i - i++ - for i < len(input) && isIdentifierPart(input[i]) { - i++ - } - name := strings.ToLower(input[start:i]) - if name != "use_connector" && name != "allow_nulls" && name != "set_limit" { - continue - } - j := skipSpaces(input, i) - if j >= len(input) || input[j] != '(' { - continue - } - body, end, ok := readCallBody(input, j) - if !ok { - continue - } - result = append(result, hintCall{name: name, args: splitCallArgs(body)}) - i = end + 1 - } - return result -} - -func readCallBody(input string, openParen int) (string, int, bool) { - depth := 0 - quote := byte(0) - for i := openParen; i < len(input); i++ { - ch := input[i] - if quote != 0 { - if ch == '\\' && i+1 < len(input) { - i++ - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '\'' || ch == '"' { - quote = ch - continue - } - if ch == '(' { - depth++ - continue - } - if ch == ')' { - depth-- - if depth == 0 { - return input[openParen+1 : i], i, true - } - } - } - return "", -1, false -} - -func splitCallArgs(input string) []string { - args := make([]string, 0) - current := strings.Builder{} - depth := 0 - quote := byte(0) - for i := 0; i < len(input); i++ { - ch := input[i] - if quote != 0 { - current.WriteByte(ch) - if ch == '\\' && i+1 < len(input) { - i++ - current.WriteByte(input[i]) - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '\'' || ch == '"' { - quote = ch - current.WriteByte(ch) - continue - } - if ch == '(' { - depth++ - current.WriteByte(ch) - continue - } - if ch == ')' { - if depth > 0 { - depth-- - } - current.WriteByte(ch) - continue - } - if ch == ',' && depth == 0 { - args = append(args, strings.TrimSpace(current.String())) - current.Reset() - continue - } - current.WriteByte(ch) + names := map[string]bool{ + "use_connector": true, + "allow_nulls": true, + "groupable": true, + "grouping_enabled": true, + "allowed_order_by_columns": true, + "set_limit": true, + "set_cache": true, + "cardinality": true, + "self_ref": true, + "dest": true, + "type": true, } - if value := strings.TrimSpace(current.String()); value != "" { - args = append(args, value) + parsed, _ := decl.ScanCalls(input, decl.CallScanOptions{ + AllowedNames: names, + RequireDollar: false, + AllowDollar: false, + Strict: false, + }) + result := make([]hintCall, 0, len(parsed)) + for _, call := range parsed { + result = append(result, hintCall{name: call.Name, args: call.Args}) } - return args + return result } func isIdentifierStart(ch byte) bool { @@ -213,22 +235,19 @@ func unquote(value string) string { return value } -func skipSpaces(input string, index int) int { - for index < len(input) { - switch input[index] { - case ' ', '\t', '\n', '\r': - index++ - default: - return index - } - } - return index -} - -func appendRelationViews(result *plan.Result, root *plan.View, hints map[string]viewHint) { +func appendRelationViews(result *plan.Result, root *plan.View, hints map[string]viewHint, rawDQL string) { if result == nil || root == nil || len(root.Relations) == 0 { return } + joinSQLByAlias := map[string]string{} + for _, item := range scanJoinSubqueries(rawDQL) { + alias := strings.TrimSpace(item.alias) + body := strings.TrimSpace(item.body) + if alias == "" || body == "" { + continue + } + joinSQLByAlias[strings.ToLower(alias)] = body + } for _, relation := range root.Relations { if relation == nil { continue @@ -247,6 +266,10 @@ func appendRelationViews(result *plan.Result, root *plan.View, hints map[string] continue } table := strings.TrimSpace(relation.Table) + sqlText := strings.TrimSpace(joinSQLByAlias[strings.ToLower(name)]) + if sqlText == "" { + sqlText = relationSQLText(table) + } if table == "" { table = name } @@ -256,10 +279,14 @@ func appendRelationViews(result *plan.Result, root *plan.View, hints map[string] Holder: name, Name: name, Table: table, + SQL: sqlText, Cardinality: "many", FieldType: reflect.TypeOf([]map[string]interface{}{}), ElementType: reflect.TypeOf(map[string]interface{}{}), } + if len(relation.ColumnsConfig) > 0 { + view.Declaration = &plan.ViewDeclaration{ColumnsConfig: relation.ColumnsConfig} + } result.Views = append(result.Views, view) result.ViewsByName[name] = view } @@ -277,11 +304,7 @@ func applyViewHints(result *plan.Result, hints map[string]viewHint) { continue } for _, key := range []string{item.Name, item.Holder} { - key = strings.TrimSpace(key) - if key == "" { - continue - } - hint, ok := hints[key] + hint, ok := lookupViewHint(hints, key) if !ok { continue } @@ -292,14 +315,92 @@ func applyViewHints(result *plan.Result, hints map[string]viewHint) { value := *hint.AllowNulls item.AllowNulls = &value } + if item.Groupable == nil && hint.Groupable != nil { + value := *hint.Groupable + item.Groupable = &value + } if item.SelectorNoLimit == nil && hint.NoLimit != nil { value := *hint.NoLimit item.SelectorNoLimit = &value } + if item.SelectorOrderBy == nil && hint.SelectorOrderBy != nil { + value := *hint.SelectorOrderBy + item.SelectorOrderBy = &value + } + if len(item.SelectorOrderByColumns) == 0 && len(hint.SelectorOrderByNames) > 0 { + item.SelectorOrderByColumns = map[string]string{} + for key, value := range hint.SelectorOrderByNames { + item.SelectorOrderByColumns[key] = value + } + } + if item.SelectorLimit == nil && hint.Limit != nil { + value := *hint.Limit + item.SelectorLimit = &value + } + if item.CacheRef == "" && hint.CacheRef != "" { + item.CacheRef = hint.CacheRef + } + if hint.Cardinality != "" { + item.Cardinality = hint.Cardinality + } + if item.Self == nil && hint.Self != nil { + item.Self = hint.Self + } + if hint.Dest != "" || hint.TypeName != "" { + if item.Declaration == nil { + item.Declaration = &plan.ViewDeclaration{} + } + if item.Declaration.Dest == "" && hint.Dest != "" { + item.Declaration.Dest = hint.Dest + } + if item.Declaration.TypeName == "" && hint.TypeName != "" { + item.Declaration.TypeName = hint.TypeName + } + } } } } +func normalizeHintAlias(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func appendAllowedOrderByColumns(target map[string]string, columns string) { + for _, expression := range strings.Split(columns, ",") { + expression = strings.TrimSpace(expression) + if expression == "" { + continue + } + key := expression + value := expression + if strings.Contains(expression, ":") { + parts := strings.SplitN(expression, ":", 2) + key = strings.TrimSpace(parts[0]) + value = strings.TrimSpace(parts[1]) + } + if key == "" || value == "" { + continue + } + target[key] = value + lcKey := strings.ToLower(key) + if lcKey != key { + target[lcKey] = value + } + if index := strings.Index(key, "."); index != -1 && index+1 < len(key) { + target[key[index+1:]] = value + } + } +} + +func lookupViewHint(hints map[string]viewHint, key string) (viewHint, bool) { + key = normalizeHintAlias(key) + if key == "" { + return viewHint{}, false + } + hint, ok := hints[key] + return hint, ok +} + func normalizeRelationTable(table string) string { table = strings.TrimSpace(table) if table == "" { @@ -332,3 +433,65 @@ func normalizeRelationTable(table string) string { } return normalized } + +func relationSQLText(table string) string { + trimmed := strings.TrimSpace(table) + if trimmed == "" { + return "" + } + normalized := strings.ToLower(trimmed) + if strings.HasPrefix(normalized, "select ") { + return trimmed + } + if strings.HasPrefix(trimmed, "(") { + unwrapped := unwrapRelationParens(trimmed) + unwrappedLower := strings.ToLower(strings.TrimSpace(unwrapped)) + if strings.HasPrefix(unwrappedLower, "select ") { + return strings.TrimSpace(unwrapped) + } + } + return "" +} + +func unwrapRelationParens(input string) string { + input = strings.TrimSpace(input) + if len(input) < 2 || input[0] != '(' || input[len(input)-1] != ')' { + return input + } + depth := 0 + quote := byte(0) + for i := 0; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + switch ch { + case '(': + depth++ + case ')': + depth-- + if depth == 0 && i != len(input)-1 { + return input + } + } + } + if depth != 0 { + return input + } + inner := strings.TrimSpace(input[1 : len(input)-1]) + if inner == "" { + return input + } + return inner +} diff --git a/repository/shape/compile/hints_strip.go b/repository/shape/compile/hints_strip.go new file mode 100644 index 000000000..080e78cab --- /dev/null +++ b/repository/shape/compile/hints_strip.go @@ -0,0 +1,83 @@ +package compile + +import ( + "strings" + + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/query" +) + +var projectionHintCalls = map[string]bool{ + "useconnector": true, + "allownulls": true, + "groupable": true, + "groupingenabled": true, + "allowedorderbycolumns": true, + "setlimit": true, + "setcache": true, + "cardinality": true, + "selfref": true, + "dest": true, + "type": true, +} + +// stripProjectionHintCalls removes hint-only projection functions (e.g. self_ref, dest) +// from executable SQL while preserving metadata extraction from original DQL. +func stripProjectionHintCalls(sqlText string) string { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" { + return sqlText + } + queryNode, err := sqlparser.ParseQuery(sqlText) + if err != nil || queryNode == nil { + return sqlText + } + if !stripHintCallsFromSelect(queryNode) { + return sqlText + } + return strings.TrimSpace(sqlparser.Stringify(queryNode)) +} + +func stripHintCallsFromSelect(node *query.Select) bool { + if node == nil || len(node.List) == 0 { + return false + } + filtered := make(query.List, 0, len(node.List)) + changed := false + for _, item := range node.List { + if item == nil { + continue + } + if isHintProjectionItem(item) { + changed = true + continue + } + filtered = append(filtered, item) + } + // Keep original list if stripping would produce invalid SELECT list. + if changed && len(filtered) > 0 { + node.List = filtered + return true + } + return false +} + +func isHintProjectionItem(item *query.Item) bool { + if item == nil || item.Expr == nil { + return false + } + call, ok := item.Expr.(*expr.Call) + if !ok || call.X == nil { + return false + } + name := normalizeHintCallName(sqlparser.Stringify(call.X)) + return projectionHintCalls[name] +} + +func normalizeHintCallName(name string) string { + name = strings.ToLower(strings.TrimSpace(name)) + name = strings.Trim(name, "`\"'") + name = strings.ReplaceAll(name, "_", "") + return name +} diff --git a/repository/shape/compile/hints_test.go b/repository/shape/compile/hints_test.go index c40470e0f..72f986770 100644 --- a/repository/shape/compile/hints_test.go +++ b/repository/shape/compile/hints_test.go @@ -1,6 +1,7 @@ package compile import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -9,17 +10,42 @@ import ( ) func TestExtractViewHints_WithQuotedConnector(t *testing.T) { - dql := "SELECT use_connector(match, 'bq_sitemgmt_match'), use_connector(site, \"ci_ads\"), allow_nulls(match), set_limit(match, 0)" + dql := "SELECT use_connector(match, 'bq_sitemgmt_match'), use_connector(site, \"ci_ads\"), allow_nulls(match), groupable(match), set_limit(match, 0)" hints := extractViewHints(dql) require.Len(t, hints, 2) assert.Equal(t, "bq_sitemgmt_match", hints["match"].Connector) assert.Equal(t, "ci_ads", hints["site"].Connector) require.NotNil(t, hints["match"].AllowNulls) assert.True(t, *hints["match"].AllowNulls) + require.NotNil(t, hints["match"].Groupable) + assert.True(t, *hints["match"].Groupable) require.NotNil(t, hints["match"].NoLimit) assert.True(t, *hints["match"].NoLimit) } +func TestExtractViewHints_GroupingEnabledAlias(t *testing.T) { + dql := "SELECT grouping_enabled(match), set_limit(match, 0)" + hints := extractViewHints(dql) + require.Contains(t, hints, "match") + require.NotNil(t, hints["match"].Groupable) + assert.True(t, *hints["match"].Groupable) +} + +func TestExtractViewHints_AllowedOrderByColumns(t *testing.T) { + dql := "SELECT allowed_order_by_columns(vendor, 'accountId:ACCOUNT_ID,vendor.userCreated:USER_CREATED,totalId:TOTAL_ID')" + hints := extractViewHints(dql) + require.Contains(t, hints, "vendor") + require.NotNil(t, hints["vendor"].SelectorOrderBy) + assert.True(t, *hints["vendor"].SelectorOrderBy) + assert.Equal(t, "ACCOUNT_ID", hints["vendor"].SelectorOrderByNames["accountId"]) + assert.Equal(t, "ACCOUNT_ID", hints["vendor"].SelectorOrderByNames["accountid"]) + assert.Equal(t, "USER_CREATED", hints["vendor"].SelectorOrderByNames["vendor.userCreated"]) + assert.Equal(t, "USER_CREATED", hints["vendor"].SelectorOrderByNames["vendor.usercreated"]) + assert.Equal(t, "USER_CREATED", hints["vendor"].SelectorOrderByNames["userCreated"]) + assert.Equal(t, "TOTAL_ID", hints["vendor"].SelectorOrderByNames["totalId"]) + assert.Equal(t, "TOTAL_ID", hints["vendor"].SelectorOrderByNames["totalid"]) +} + func TestExtractViewHints_MixedCaseAndUnquotedConnector(t *testing.T) { dql := "SELECT USE_CONNECTOR(match, ci_ads), Allow_Nulls(match), set_limit(match, -1)" hints := extractViewHints(dql) @@ -31,24 +57,263 @@ func TestExtractViewHints_MixedCaseAndUnquotedConnector(t *testing.T) { assert.False(t, *hints["match"].NoLimit) } +func TestExtractViewHints_DestAndType(t *testing.T) { + dql := "SELECT dest(vendor,'vendor.go'), type(vendor,'Vendor'), dest(products,'vendor.go'), type(products,'Products') FROM VENDOR vendor" + hints := extractViewHints(dql) + require.Contains(t, hints, "vendor") + require.Contains(t, hints, "products") + assert.Equal(t, "vendor.go", hints["vendor"].Dest) + assert.Equal(t, "Vendor", hints["vendor"].TypeName) + assert.Equal(t, "vendor.go", hints["products"].Dest) + assert.Equal(t, "Products", hints["products"].TypeName) +} + +func TestExtractViewHints_Cardinality(t *testing.T) { + dql := "SELECT cardinality(products_meta, 'one'), cardinality(products, 'many')" + hints := extractViewHints(dql) + require.Contains(t, hints, "products_meta") + require.Contains(t, hints, "products") + assert.Equal(t, "one", hints["products_meta"].Cardinality) + assert.Equal(t, "many", hints["products"].Cardinality) +} + func TestApplyViewHints_Metadata(t *testing.T) { trueValue := true result := &plan.Result{ Views: []*plan.View{ - {Name: "match", Table: "MATCH"}, + {Name: "match", Table: "MATCH", Cardinality: "many"}, }, } applyViewHints(result, map[string]viewHint{ "match": { - Connector: "ci_ads", - AllowNulls: &trueValue, - NoLimit: &trueValue, + Connector: "ci_ads", + AllowNulls: &trueValue, + Groupable: &trueValue, + NoLimit: &trueValue, + Cardinality: "one", + Dest: "match.go", + TypeName: "Match", + SelectorOrderBy: &trueValue, + SelectorOrderByNames: map[string]string{ + "accountId": "ACCOUNT_ID", + }, }, }) require.Len(t, result.Views, 1) assert.Equal(t, "ci_ads", result.Views[0].Connector) require.NotNil(t, result.Views[0].AllowNulls) assert.True(t, *result.Views[0].AllowNulls) + require.NotNil(t, result.Views[0].Groupable) + assert.True(t, *result.Views[0].Groupable) require.NotNil(t, result.Views[0].SelectorNoLimit) assert.True(t, *result.Views[0].SelectorNoLimit) + require.NotNil(t, result.Views[0].SelectorOrderBy) + assert.True(t, *result.Views[0].SelectorOrderBy) + assert.Equal(t, "ACCOUNT_ID", result.Views[0].SelectorOrderByColumns["accountId"]) + assert.Equal(t, "one", strings.ToLower(result.Views[0].Cardinality)) + require.NotNil(t, result.Views[0].Declaration) + assert.Equal(t, "match.go", result.Views[0].Declaration.Dest) + assert.Equal(t, "Match", result.Views[0].Declaration.TypeName) +} + +func TestApplyViewHints_MetadataCaseInsensitiveAlias(t *testing.T) { + trueValue := true + result := &plan.Result{ + Views: []*plan.View{ + {Name: "User", Holder: "User", Table: "USER"}, + }, + } + applyViewHints(result, map[string]viewHint{ + "user": { + Self: &plan.SelfReference{Holder: "Team", Child: "ID", Parent: "MGR_ID"}, + AllowNulls: &trueValue, + }, + }) + require.Len(t, result.Views, 1) + require.NotNil(t, result.Views[0].Self) + assert.Equal(t, "Team", result.Views[0].Self.Holder) + assert.Equal(t, "ID", result.Views[0].Self.Child) + assert.Equal(t, "MGR_ID", result.Views[0].Self.Parent) +} + +func TestStripProjectionHintCalls_RemovesSelfRefFromSQL(t *testing.T) { + sqlText := "SELECT user.* EXCEPT MGR_ID, self_ref(user, 'Team', 'ID', 'MGR_ID'), cardinality(user, 'one'), groupable(user), allowed_order_by_columns(user, 'id:ID') FROM (SELECT t.* FROM USER t) user" + actual := stripProjectionHintCalls(sqlText) + assert.NotContains(t, strings.ToLower(actual), "self_ref(") + assert.NotContains(t, strings.ToLower(actual), "cardinality(") + assert.NotContains(t, strings.ToLower(actual), "groupable(") + assert.NotContains(t, strings.ToLower(actual), "allowed_order_by_columns(") + assert.Contains(t, strings.ToLower(actual), "user.* except mgr_id") +} + +func TestStripProjectionHintCalls_RemovesGroupingEnabledAlias(t *testing.T) { + sqlText := "SELECT user.*, grouping_enabled(user) FROM (SELECT t.* FROM USER t) user" + actual := stripProjectionHintCalls(sqlText) + assert.NotContains(t, strings.ToLower(actual), "grouping_enabled(") + assert.Contains(t, strings.ToLower(actual), "user.*") +} + +func TestAppendRelationViews_SQLSelection(t *testing.T) { + testCases := []struct { + name string + rawDQL string + relationTable string + expectContains string + expectNotContain string + }{ + { + name: "prefers raw join subquery SQL when available", + rawDQL: ` +SELECT wrapper.*, + vendor.* +FROM (SELECT ID FROM VENDOR WHERE ID = $vendorID) wrapper +JOIN (SELECT * FROM VENDOR t WHERE t.ID = $criteria.AppendBinding($Unsafe.vendorID)) vendor ON vendor.ID = wrapper.ID`, + relationTable: "(SELECT * FROM VENDOR t WHERE t.ID = 1)", + expectContains: "$criteria.AppendBinding($Unsafe.vendorID)", + expectNotContain: "t.ID = 1", + }, + { + name: "falls back to relation table SQL when raw join SQL missing", + rawDQL: ` +SELECT wrapper.* +FROM (SELECT ID FROM VENDOR WHERE ID = $vendorID) wrapper`, + relationTable: "(SELECT * FROM VENDOR t WHERE t.ID = 1)", + expectContains: "t.ID = 1", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + } + root := &plan.View{ + Relations: []*plan.Relation{ + { + Name: "vendor", + Ref: "vendor", + Table: testCase.relationTable, + On: []*plan.RelationLink{ + {Expression: "vendor.ID = wrapper.ID"}, + }, + }, + }, + } + + appendRelationViews(result, root, nil, testCase.rawDQL) + require.Len(t, result.Views, 1) + assert.Contains(t, result.Views[0].SQL, testCase.expectContains) + if testCase.expectNotContain != "" { + assert.NotContains(t, result.Views[0].SQL, testCase.expectNotContain) + } + }) + } +} + +func TestAppendRelationViews_ComplexTreeAnyLevel(t *testing.T) { + rawDQL := ` +SELECT wrapper.*, + vendor.*, + products.*, + reviews.* +FROM (SELECT ID FROM VENDOR WHERE ID = $vendorID) wrapper +JOIN (SELECT * FROM VENDOR t WHERE t.ID = $criteria.AppendBinding($Unsafe.vendorID)) vendor ON vendor.ID = wrapper.ID +JOIN (SELECT * FROM PRODUCT p WHERE p.VENDOR_ID = $criteria.AppendBinding($Unsafe.vendorID)) products ON products.VENDOR_ID = vendor.ID +JOIN (SELECT * FROM REVIEW r WHERE r.PRODUCT_ID = products.ID) reviews ON reviews.PRODUCT_ID = products.ID` + + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + } + root := &plan.View{ + Relations: []*plan.Relation{ + { + Name: "vendor", + Ref: "vendor", + Parent: "wrapper", + Table: "(SELECT * FROM VENDOR t WHERE t.ID = 1)", + On: []*plan.RelationLink{ + {Expression: "vendor.ID = wrapper.ID"}, + }, + }, + { + Name: "products", + Ref: "products", + Parent: "vendor", + Table: "(SELECT * FROM PRODUCT p WHERE p.VENDOR_ID = 1)", + On: []*plan.RelationLink{ + {Expression: "products.VENDOR_ID = vendor.ID"}, + }, + }, + { + Name: "reviews", + Ref: "reviews", + Parent: "products", + Table: "(SELECT * FROM REVIEW r WHERE r.PRODUCT_ID = products.ID)", + On: []*plan.RelationLink{ + {Expression: "reviews.PRODUCT_ID = products.ID"}, + }, + }, + }, + } + + appendRelationViews(result, root, nil, rawDQL) + require.Len(t, result.Views, 3) + require.Contains(t, result.ViewsByName, "vendor") + require.Contains(t, result.ViewsByName, "products") + require.Contains(t, result.ViewsByName, "reviews") + assert.Contains(t, result.ViewsByName["vendor"].SQL, "$criteria.AppendBinding($Unsafe.vendorID)") + assert.Contains(t, result.ViewsByName["products"].SQL, "$criteria.AppendBinding($Unsafe.vendorID)") + assert.Contains(t, result.ViewsByName["reviews"].SQL, "r.PRODUCT_ID = products.ID") +} + +func TestAppendRelationViews_ComplexTreeCrossLevelJoin(t *testing.T) { + rawDQL := ` +SELECT wrapper.*, + vendor.*, + products.*, + stats.* +FROM (SELECT ID FROM VENDOR WHERE ID = $vendorID) wrapper +JOIN (SELECT * FROM VENDOR t WHERE t.ID = $criteria.AppendBinding($Unsafe.vendorID)) vendor ON vendor.ID = wrapper.ID +JOIN (SELECT * FROM PRODUCT p WHERE p.VENDOR_ID = vendor.ID) products ON products.VENDOR_ID = vendor.ID +JOIN (SELECT COUNT(1) AS CNT, v.ID AS VENDOR_ID FROM VENDOR v WHERE v.ID = wrapper.ID) stats ON stats.VENDOR_ID = wrapper.ID` + + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + } + root := &plan.View{ + Relations: []*plan.Relation{ + { + Name: "vendor", + Ref: "vendor", + Parent: "wrapper", + Table: "(SELECT * FROM VENDOR t WHERE t.ID = 1)", + On: []*plan.RelationLink{ + {Expression: "vendor.ID = wrapper.ID"}, + }, + }, + { + Name: "products", + Ref: "products", + Parent: "vendor", + Table: "(SELECT * FROM PRODUCT p WHERE p.VENDOR_ID = vendor.ID)", + On: []*plan.RelationLink{ + {Expression: "products.VENDOR_ID = vendor.ID"}, + }, + }, + { + Name: "stats", + Ref: "stats", + Parent: "products", + Table: "(SELECT COUNT(1) AS CNT, v.ID AS VENDOR_ID FROM VENDOR v WHERE v.ID = wrapper.ID)", + On: []*plan.RelationLink{ + {Expression: "stats.VENDOR_ID = wrapper.ID"}, + }, + }, + }, + } + + appendRelationViews(result, root, nil, rawDQL) + require.Len(t, result.Views, 3) + require.Contains(t, result.ViewsByName, "stats") + assert.Contains(t, result.ViewsByName["stats"].SQL, "v.ID = wrapper.ID") } diff --git a/repository/shape/compile/inline_param.go b/repository/shape/compile/inline_param.go new file mode 100644 index 000000000..af5b233a5 --- /dev/null +++ b/repository/shape/compile/inline_param.go @@ -0,0 +1,103 @@ +package compile + +import ( + "encoding/json" + "strings" + + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view/state" +) + +type inlineParamHint struct { + Kind string `json:"Kind"` + Location string `json:"Location"` + DataType string `json:"DataType"` + Required *bool `json:"Required"` +} + +// applyInlineParamHints scans the SQL text for patterns like +// $varName /* {"Kind":"header","Location":"Header-Name"} */ and updates +// matching state parameters in the plan result. +func applyInlineParamHints(sqlText string, result *plan.Result) { + if result == nil || strings.TrimSpace(sqlText) == "" { + return + } + hints := extractInlineParamHints(sqlText) + if len(hints) == 0 { + return + } + for _, st := range result.States { + if st == nil { + continue + } + name := strings.TrimPrefix(strings.TrimSpace(st.Name), "$") + hint, ok := hints[name] + if !ok { + continue + } + if hint.Kind != "" && st.In != nil { + st.In.Kind = state.Kind(strings.ToLower(hint.Kind)) + } + if hint.Location != "" && st.In != nil { + st.In.Name = hint.Location + } + if hint.DataType != "" { + ensureStateSchema(st).DataType = hint.DataType + } + if hint.Required != nil { + st.Required = hint.Required + } + } +} + +func extractInlineParamHints(sql string) map[string]inlineParamHint { + result := map[string]inlineParamHint{} + i := 0 + for i < len(sql) { + if sql[i] != '$' { + i++ + continue + } + i++ + start := i + for i < len(sql) && isParamIdentPart(sql[i]) { + i++ + } + if i == start { + continue + } + name := sql[start:i] + j := skipInlineSpaces(sql, i) + if j+1 >= len(sql) || sql[j] != '/' || sql[j+1] != '*' { + continue + } + endComment := strings.Index(sql[j+2:], "*/") + if endComment < 0 { + continue + } + body := strings.TrimSpace(sql[j+2 : j+2+endComment]) + if !strings.HasPrefix(body, "{") || !strings.HasSuffix(body, "}") { + continue + } + var hint inlineParamHint + if err := json.Unmarshal([]byte(body), &hint); err != nil { + continue + } + if hint.Kind != "" || hint.Location != "" || hint.DataType != "" { + result[name] = hint + } + i = j + 2 + endComment + 2 + } + return result +} + +func isParamIdentPart(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' +} + +func skipInlineSpaces(input string, index int) int { + for index < len(input) && (input[index] == ' ' || input[index] == '\t' || input[index] == '\n' || input[index] == '\r') { + index++ + } + return index +} diff --git a/repository/shape/compile/inline_param_test.go b/repository/shape/compile/inline_param_test.go new file mode 100644 index 000000000..13079090a --- /dev/null +++ b/repository/shape/compile/inline_param_test.go @@ -0,0 +1,36 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view/state" +) + +func TestExtractInlineParamHints(t *testing.T) { + sql := `SELECT * FROM VENDOR t WHERE t.ID = $vendorID /* {"Kind": "header", "Location": "Vendor-Id"} */` + hints := extractInlineParamHints(sql) + require.Contains(t, hints, "vendorID") + assert.Equal(t, "header", hints["vendorID"].Kind) + assert.Equal(t, "Vendor-Id", hints["vendorID"].Location) +} + +func TestApplyInlineParamHints(t *testing.T) { + sql := `WHERE t.ID = $vendorID /* {"Kind": "header", "Location": "Vendor-Id"} */` + result := &plan.Result{ + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "vendorID", + In: &state.Location{Kind: state.KindQuery, Name: "vendorID"}, + }, + }, + }, + } + applyInlineParamHints(sql, result) + require.Len(t, result.States, 1) + assert.Equal(t, state.KindHeader, result.States[0].In.Kind) + assert.Equal(t, "Vendor-Id", result.States[0].In.Name) +} diff --git a/repository/shape/compile/pipeline/infer.go b/repository/shape/compile/pipeline/infer.go index 7ad212431..435b3e5d7 100644 --- a/repository/shape/compile/pipeline/infer.go +++ b/repository/shape/compile/pipeline/infer.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strings" + "unicode" "github.com/viant/sqlparser" "github.com/viant/sqlparser/query" @@ -134,6 +135,11 @@ func InferProjectionType(queryNode *query.Select) (reflect.Type, reflect.Type, s if queryNode == nil || len(queryNode.List) == 0 || queryNode.List.IsStarExpr() { return reflect.TypeOf([]map[string]interface{}{}), reflect.TypeOf(map[string]interface{}{}), "many" } + for _, item := range queryNode.List { + if requiresDeferredProjectionType(sqlparser.Stringify(item)) { + return reflect.TypeOf([]map[string]interface{}{}), reflect.TypeOf(map[string]interface{}{}), "many" + } + } fields := make([]reflect.StructField, 0, len(queryNode.List)) used := map[string]int{} for index, item := range queryNode.List { @@ -151,17 +157,47 @@ func InferProjectionType(queryNode *query.Select) (reflect.Type, reflect.Type, s } used[fieldName]++ - typ := parseColumnType(column.Type) + typ := inferColumnType(sqlparser.Stringify(item), column.Type) + veltyNames := []string{columnName} + if fieldName != "" && fieldName != columnName { + veltyNames = append(veltyNames, fieldName) + } fields = append(fields, reflect.StructField{ Name: fieldName, Type: typ, - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s,omitempty" sqlx:"name=%s"`, strings.ToLower(fieldName), columnName)), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s,omitempty" sqlx:"name=%s" velty:"names=%s"`, lowerCamel(fieldName), columnName, strings.Join(veltyNames, "|"))), }) } element := reflect.StructOf(fields) return reflect.SliceOf(element), element, "many" } +func requiresDeferredProjectionType(expression string) bool { + expression = strings.ToLower(strings.TrimSpace(expression)) + if expression == "" { + return false + } + if strings.Contains(expression, ".*") { + return true + } + if strings.Contains(expression, " except ") { + return true + } + if strings.HasPrefix(expression, "allow_nulls(") { + return true + } + return false +} + +func lowerCamel(value string) string { + if value == "" { + return "" + } + runes := []rune(value) + runes[0] = unicode.ToLower(runes[0]) + return string(runes) +} + func SanitizeName(value string) string { value = strings.TrimSpace(value) if value == "" { @@ -182,7 +218,11 @@ func SanitizeName(value string) string { } func ExportedName(value string) string { - value = replaceNonWordWithUnderscore(strings.TrimSpace(value)) + value = strings.TrimSpace(value) + if preserved := preserveMixedCaseIdentifier(value); preserved != "" { + return preserved + } + value = replaceNonWordWithUnderscore(value) value = strings.Trim(value, "_") if value == "" { return "" @@ -204,6 +244,37 @@ func ExportedName(value string) string { return name } +func preserveMixedCaseIdentifier(value string) string { + if value == "" { + return "" + } + hasLower := false + hasUpperAfterFirst := false + for i, r := range value { + if !(unicode.IsLetter(r) || unicode.IsDigit(r)) { + return "" + } + if unicode.IsLower(r) { + hasLower = true + } + if i > 0 && unicode.IsUpper(r) { + hasUpperAfterFirst = true + } + } + if !hasLower || !hasUpperAfterFirst { + return "" + } + runes := []rune(value) + if len(runes) == 0 { + return "" + } + if unicode.IsDigit(runes[0]) { + return "N" + value + } + runes[0] = unicode.ToUpper(runes[0]) + return string(runes) +} + func replaceNonWordWithUnderscore(value string) string { if value == "" { return "" @@ -233,7 +304,7 @@ func parseColumnType(dataType string) reflect.Type { return reflect.TypeOf("") case "bool", "boolean": return reflect.TypeOf(false) - case "int", "int32", "smallint", "integer": + case "int", "int32", "smallint", "integer", "signed": return reflect.TypeOf(int(0)) case "int64", "bigint": return reflect.TypeOf(int64(0)) @@ -245,3 +316,54 @@ func parseColumnType(dataType string) reflect.Type { return reflect.TypeOf("") } } + +func inferColumnType(expression, dataType string) reflect.Type { + lower := strings.ToLower(strings.TrimSpace(expression)) + switch { + case isPureAggregateProjection(lower, "count("): + return reflect.TypeOf(int(0)) + case strings.Contains(lower, " as signed"), strings.Contains(lower, " as integer"), strings.Contains(lower, " as int)"), strings.Contains(lower, " as int "): + if isComputedNumericProjection(lower) { + return reflect.TypeOf((*int)(nil)) + } + return reflect.TypeOf(int(0)) + case strings.Contains(lower, " as bigint"): + if isComputedNumericProjection(lower) { + return reflect.TypeOf((*int64)(nil)) + } + return reflect.TypeOf(int64(0)) + case strings.Contains(lower, "sum("), strings.Contains(lower, "avg("): + return reflect.TypeOf(float64(0)) + default: + dataType = strings.TrimSpace(dataType) + return parseColumnType(dataType) + } +} + +func isPureAggregateProjection(expression string, aggregate string) bool { + idx := strings.Index(expression, aggregate) + if idx == -1 { + return false + } + return strings.TrimSpace(expression[:idx]) == "" +} + +func isComputedNumericProjection(expression string) bool { + expression = strings.ToLower(strings.TrimSpace(expression)) + switch { + case strings.Contains(expression, "count("): + return !isPureAggregateProjection(expression, "count(") + case strings.Contains(expression, "sum("): + return !isPureAggregateProjection(expression, "sum(") + case strings.Contains(expression, "avg("): + return !isPureAggregateProjection(expression, "avg(") + } + return strings.Contains(expression, " + ") || + strings.Contains(expression, " - ") || + strings.Contains(expression, " * ") || + strings.Contains(expression, " / ") || + strings.Contains(expression, "case ") || + strings.Contains(expression, "coalesce(") || + strings.Contains(expression, "nullif(") || + strings.Contains(expression, "cast(") +} diff --git a/repository/shape/compile/pipeline/infer_test.go b/repository/shape/compile/pipeline/infer_test.go index 748fcded4..2ae5bb82b 100644 --- a/repository/shape/compile/pipeline/infer_test.go +++ b/repository/shape/compile/pipeline/infer_test.go @@ -1,9 +1,11 @@ package pipeline import ( + "reflect" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/viant/sqlparser" ) @@ -40,3 +42,51 @@ func TestInferTableFromSQL_ResolvesTopLevelFrom(t *testing.T) { sqlText := `SELECT a.*, EXISTS(SELECT 1 FROM CI_ENTITY_WATCHLIST w WHERE w.ENTITY_ID = a.ID) AS watching FROM (SELECT x.* FROM CI_ADVERTISER x) a` assert.Equal(t, "CI_ADVERTISER", InferTableFromSQL(sqlText)) } + +func TestExportedName_PreservesMixedCaseIdentifiers(t *testing.T) { + assert.Equal(t, "UserID", ExportedName("UserID")) + assert.Equal(t, "IsReadOnly", ExportedName("IsReadOnly")) + assert.Equal(t, "VendorName", ExportedName("vendor_name")) +} + +func TestInferProjectionType_AddsVeltyNames(t *testing.T) { + queryNode, err := sqlparser.ParseQuery(`SELECT ID, IS_AUTH FROM PRODUCT`) + require.NoError(t, err) + _, element, _ := InferProjectionType(queryNode) + require.Equal(t, reflect.Struct, element.Kind()) + field, ok := element.FieldByName("IsAuth") + assert.True(t, ok) + assert.Equal(t, `names=IS_AUTH|IsAuth`, field.Tag.Get("velty")) + idField, ok := element.FieldByName("Id") + assert.True(t, ok) + assert.Equal(t, `names=ID|Id`, idField.Tag.Get("velty")) + assert.Equal(t, "isAuth,omitempty", field.Tag.Get("json")) + assert.Equal(t, "id,omitempty", idField.Tag.Get("json")) +} + +func TestInferProjectionType_InfersSummaryExpressionTypes(t *testing.T) { + queryNode, err := sqlparser.ParseQuery(`SELECT CAST(1 + (COUNT(1) / 25) AS SIGNED) AS PAGE_CNT, COUNT(1) AS CNT FROM PRODUCT`) + require.NoError(t, err) + _, element, _ := InferProjectionType(queryNode) + require.Equal(t, reflect.Struct, element.Kind()) + + pageCnt, ok := element.FieldByName("PageCnt") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), pageCnt.Type) + assert.Equal(t, "pageCnt,omitempty", pageCnt.Tag.Get("json")) + + cnt, ok := element.FieldByName("Cnt") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf(int(0)), cnt.Type) +} + +func TestInferProjectionType_DefersWildcardAliasProjection(t *testing.T) { + queryNode, err := sqlparser.ParseQuery(`SELECT vendor.*, products.* EXCEPT VENDOR_ID, allow_nulls(products) FROM VENDOR vendor JOIN PRODUCT products ON products.VENDOR_ID = vendor.ID`) + require.NoError(t, err) + + fieldType, elementType, cardinality := InferProjectionType(queryNode) + + assert.Equal(t, reflect.TypeOf([]map[string]interface{}{}), fieldType) + assert.Equal(t, reflect.TypeOf(map[string]interface{}{}), elementType) + assert.Equal(t, "many", cardinality) +} diff --git a/repository/shape/compile/pipeline/parse.go b/repository/shape/compile/pipeline/parse.go index c897454a4..9609225cc 100644 --- a/repository/shape/compile/pipeline/parse.go +++ b/repository/shape/compile/pipeline/parse.go @@ -11,7 +11,8 @@ import ( ) func ParseSelectWithDiagnostic(sqlText string) (*query.Select, *dqlshape.Diagnostic, error) { - sqlText = trimLeadingBlockComments(sqlText) + original := sqlText + sqlText, trimPrefix := trimLeadingBlockComments(sqlText) var diagnostic *dqlshape.Diagnostic onError := func(err error, cur *parsly.Cursor, _ interface{}) error { offset := 0 @@ -26,7 +27,7 @@ func ParseSelectWithDiagnostic(sqlText string) (*query.Select, *dqlshape.Diagnos Severity: dqlshape.SeverityError, Message: strings.TrimSpace(err.Error()), Hint: "check SQL syntax near the reported location", - Span: pointSpan(sqlText, offset), + Span: pointSpan(original, offset+trimPrefix), } return err } @@ -38,7 +39,7 @@ func ParseSelectWithDiagnostic(sqlText string) (*query.Select, *dqlshape.Diagnos Severity: dqlshape.SeverityError, Message: strings.TrimSpace(err.Error()), Hint: "check SQL syntax near the reported location", - Span: pointSpan(sqlText, 0), + Span: pointSpan(original, trimPrefix), } } return nil, diagnostic, err @@ -49,14 +50,16 @@ func ParseSelectWithDiagnostic(sqlText string) (*query.Select, *dqlshape.Diagnos return result, nil, nil } -func trimLeadingBlockComments(sqlText string) string { +func trimLeadingBlockComments(sqlText string) (string, int) { remaining := strings.TrimLeft(sqlText, " \t\r\n") + trimPrefix := len(sqlText) - len(remaining) for strings.HasPrefix(remaining, "/*") { end := strings.Index(remaining, "*/") if end == -1 { - return remaining + return remaining, trimPrefix } remaining = strings.TrimLeft(remaining[end+2:], " \t\r\n") + trimPrefix = len(sqlText) - len(remaining) } - return remaining + return remaining, trimPrefix } diff --git a/repository/shape/compile/pipeline/parse_test.go b/repository/shape/compile/pipeline/parse_test.go index 69292fc8a..e0222ff28 100644 --- a/repository/shape/compile/pipeline/parse_test.go +++ b/repository/shape/compile/pipeline/parse_test.go @@ -23,7 +23,7 @@ func TestParseSelectWithDiagnostic_Syntax(t *testing.T) { require.NotNil(t, diag) assert.Equal(t, dqldiag.CodeParseSyntax, diag.Code) assert.Equal(t, 1, diag.Span.Start.Line) - assert.Greater(t, diag.Span.Start.Char, 1) + assert.Equal(t, 29, diag.Span.Start.Char) } func TestParseSelectWithDiagnostic_LeadingBlockComment(t *testing.T) { @@ -33,3 +33,43 @@ func TestParseSelectWithDiagnostic_LeadingBlockComment(t *testing.T) { require.NotNil(t, queryNode) assert.Equal(t, "o", queryNode.From.Alias) } + +func TestParseSelectWithDiagnostic_SyntaxPositionMatrix(t *testing.T) { + testCases := []struct { + name string + sql string + expectedLine int + expectedChar int + }{ + { + name: "plain sql", + sql: "SELECT id FROM orders WHERE (", + expectedLine: 1, + expectedChar: 29, + }, + { + name: "with leading block comment", + sql: "/* {\"URI\":\"/x\"} */\nSELECT id FROM orders WHERE (", + expectedLine: 2, + expectedChar: 29, + }, + { + name: "with multiple leading lines and comments", + sql: "\n\n/*a*/\n/*b*/\nSELECT id FROM orders WHERE (", + expectedLine: 5, + expectedChar: 29, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + queryNode, diag, err := ParseSelectWithDiagnostic(testCase.sql) + require.Error(t, err) + require.Nil(t, queryNode) + require.NotNil(t, diag) + assert.Equal(t, dqldiag.CodeParseSyntax, diag.Code) + assert.Equal(t, testCase.expectedLine, diag.Span.Start.Line) + assert.Equal(t, testCase.expectedChar, diag.Span.Start.Char) + }) + } +} diff --git a/repository/shape/compile/pipeline/read.go b/repository/shape/compile/pipeline/read.go index c665d154b..89ee9cf4d 100644 --- a/repository/shape/compile/pipeline/read.go +++ b/repository/shape/compile/pipeline/read.go @@ -8,8 +8,11 @@ import ( "reflect" "strings" + "github.com/viant/datly/internal/inference" dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" "github.com/viant/sqlparser/query" ) @@ -17,6 +20,14 @@ import ( // It applies multiple parse strategies and gracefully degrades to a // loose (schema-less) view for template-driven SQL that cannot be fully parsed. func BuildRead(sourceName, sqlText string) (*plan.View, []*dqlshape.Diagnostic, error) { + return BuildReadWithOptions(sourceName, sqlText, nil, nil) +} + +func BuildReadWithConsts(sourceName, sqlText string, consts map[string]string) (*plan.View, []*dqlshape.Diagnostic, error) { + return BuildReadWithOptions(sourceName, sqlText, consts, nil) +} + +func BuildReadWithOptions(sourceName, sqlText string, consts map[string]string, groupableAliases map[string]bool) (*plan.View, []*dqlshape.Diagnostic, error) { queryNode, parseDiag, parserSQL, err := resolveQueryNode(sqlText) // Template-driven SQL may legitimately fail strict parsing; treat as warning. @@ -24,7 +35,9 @@ func BuildRead(sourceName, sqlText string) (*plan.View, []*dqlshape.Diagnostic, if parseDiag != nil { parseDiag.Severity = dqlshape.SeverityWarning } - return buildLooseRead(sourceName, sqlText), collectDiags(parseDiag), nil + view := buildLooseRead(sourceName, sqlText) + applyConstTables(view, consts) + return view, collectDiags(parseDiag), nil } var diags []*dqlshape.Diagnostic @@ -56,21 +69,121 @@ func BuildRead(sourceName, sqlText string) (*plan.View, []*dqlshape.Diagnostic, elementType = reflect.TypeOf(map[string]interface{}{}) cardinality = "many" } + rootSQL := sqlText + if rawRoot := extractRootSQLFromRaw(sqlText); rawRoot != "" { + rootSQL = rawRoot + } else if queryNode.From.X != nil { + fromExpr := strings.TrimSpace(sqlparser.Stringify(queryNode.From.X)) + fromExpr = trimJoinSuffix(fromExpr) + candidate := extractParenthesizedSelect(fromExpr) + if candidate == "" { + candidate = unwrapReadParens(fromExpr) + } + if candidate != "" { + rootSQL = candidate + } + } view := &plan.View{ Path: name, Holder: name, Name: name, Mode: "SQLQuery", Table: table, - SQL: sqlText, + SQL: rootSQL, Cardinality: cardinality, FieldType: fieldType, ElementType: elementType, Relations: relations, } + exceptByAlias := extractExceptColumnsByNamespace(queryNode) + groupableByAlias := extractGroupableColumnsByNamespace(queryNode, name, groupableAliases) + rootConfig := mergeColumnConfigs( + lookupExceptColumns(exceptByAlias, name), + lookupColumnConfigs(groupableByAlias, name), + extractRootGroupedColumnConfigs(rootSQL, name, groupableAliases), + ) + if len(rootConfig) > 0 { + view.Declaration = &plan.ViewDeclaration{ColumnsConfig: rootConfig} + } + applyRelationExceptColumns(relations, exceptByAlias) + applyRelationGroupableColumns(relations, groupableByAlias) + applyConstTables(view, consts) return view, diags, nil } +func applyConstTables(view *plan.View, consts map[string]string) { + if view == nil || len(consts) == 0 { + return + } + view.Table = resolveConstTable(view.Table, consts) + for _, relation := range view.Relations { + if relation == nil { + continue + } + relation.Table = resolveConstTable(relation.Table, consts) + } +} + +func resolveConstTable(table string, consts map[string]string) string { + trimmed := strings.TrimSpace(strings.Trim(table, "`\"")) + if token := unsafeSelectorToken(trimmed); token != "" { + if resolved := resolveConstValue(token, consts); resolved != "" { + return resolved + } + } + for key, value := range consts { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" || value == "" { + continue + } + placeholder := "Unsafe_" + key + if strings.EqualFold(trimmed, placeholder) { + return value + } + templatePlaceholder := "${Unsafe." + key + "}" + if strings.EqualFold(trimmed, templatePlaceholder) { + return value + } + selectorPlaceholder := "$Unsafe." + key + if strings.EqualFold(trimmed, selectorPlaceholder) { + return value + } + if strings.Contains(table, placeholder) { + table = strings.ReplaceAll(table, placeholder, value) + } + if strings.Contains(table, templatePlaceholder) { + table = strings.ReplaceAll(table, templatePlaceholder, value) + } + if strings.Contains(table, selectorPlaceholder) { + table = strings.ReplaceAll(table, selectorPlaceholder, value) + } + } + return table +} + +func unsafeSelectorToken(input string) string { + if strings.HasPrefix(input, "${Unsafe.") && strings.HasSuffix(input, "}") { + return strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(input, "${Unsafe."), "}")) + } + if strings.HasPrefix(input, "$Unsafe.") { + return strings.TrimSpace(strings.TrimPrefix(input, "$Unsafe.")) + } + return "" +} + +func resolveConstValue(token string, consts map[string]string) string { + if token == "" || len(consts) == 0 { + return "" + } + for key, value := range consts { + if strings.EqualFold(strings.TrimSpace(key), token) && strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} + // resolveQueryNode attempts to parse sqlText into a query AST using up to // three strategies: // 1. Parse the normalised form. @@ -170,6 +283,206 @@ func inferRootFromRelations(relations []*plan.Relation) string { return "" } +func extractRootExceptColumns(queryNode *query.Select, rootName string) map[string]*plan.ViewColumnConfig { + return lookupExceptColumns(extractExceptColumnsByNamespace(queryNode), rootName) +} + +func applyRelationExceptColumns(relations []*plan.Relation, exceptByAlias map[string]map[string]*plan.ViewColumnConfig) { + if len(relations) == 0 || len(exceptByAlias) == 0 { + return + } + for _, relation := range relations { + if relation == nil { + continue + } + if columns := lookupExceptColumns(exceptByAlias, relation.Ref); len(columns) > 0 { + relation.ColumnsConfig = columns + } + } +} + +func applyRelationGroupableColumns(relations []*plan.Relation, groupableByAlias map[string]map[string]*plan.ViewColumnConfig) { + if len(relations) == 0 || len(groupableByAlias) == 0 { + return + } + for _, relation := range relations { + if relation == nil { + continue + } + relation.ColumnsConfig = mergeColumnConfigs(relation.ColumnsConfig, lookupColumnConfigs(groupableByAlias, relation.Ref)) + } +} + +func lookupExceptColumns(exceptByAlias map[string]map[string]*plan.ViewColumnConfig, alias string) map[string]*plan.ViewColumnConfig { + return lookupColumnConfigs(exceptByAlias, alias) +} + +func lookupColumnConfigs(byAlias map[string]map[string]*plan.ViewColumnConfig, alias string) map[string]*plan.ViewColumnConfig { + if len(byAlias) == 0 { + return nil + } + alias = strings.ToLower(strings.TrimSpace(alias)) + if alias == "" { + return nil + } + result := byAlias[alias] + if len(result) == 0 { + return nil + } + ret := make(map[string]*plan.ViewColumnConfig, len(result)) + for key, cfg := range result { + if cfg == nil { + continue + } + cloned := *cfg + if cfg.Groupable != nil { + value := *cfg.Groupable + cloned.Groupable = &value + } + ret[key] = &cloned + } + if len(ret) == 0 { + return nil + } + return ret +} + +func mergeColumnConfigs(base map[string]*plan.ViewColumnConfig, overlays ...map[string]*plan.ViewColumnConfig) map[string]*plan.ViewColumnConfig { + var result map[string]*plan.ViewColumnConfig + if len(base) > 0 { + result = lookupColumnConfigs(map[string]map[string]*plan.ViewColumnConfig{"_": base}, "_") + } + for _, overlay := range overlays { + for name, cfg := range overlay { + name = strings.TrimSpace(name) + if name == "" || cfg == nil { + continue + } + if result == nil { + result = map[string]*plan.ViewColumnConfig{} + } + target := result[name] + if target == nil { + target = &plan.ViewColumnConfig{} + result[name] = target + } + if dataType := strings.TrimSpace(cfg.DataType); dataType != "" && target.DataType == "" { + target.DataType = dataType + } + if tag := strings.TrimSpace(cfg.Tag); tag != "" && target.Tag == "" { + target.Tag = tag + } + if target.Groupable == nil && cfg.Groupable != nil { + value := *cfg.Groupable + target.Groupable = &value + } + } + } + if len(result) == 0 { + return nil + } + return result +} + +func extractGroupableColumnsByNamespace(queryNode *query.Select, rootName string, enabledAliases map[string]bool) map[string]map[string]*plan.ViewColumnConfig { + if queryNode == nil || len(enabledAliases) == 0 { + return nil + } + columns := sqlparser.NewColumns(queryNode.List) + groupable := inference.GroupableColumns(queryNode, columns) + if len(groupable) == 0 { + return nil + } + rootName = strings.ToLower(strings.TrimSpace(rootName)) + result := map[string]map[string]*plan.ViewColumnConfig{} + for _, column := range columns { + if column == nil || !groupable[column.Identity()] { + continue + } + name := strings.TrimSpace(column.Identity()) + if name == "" { + continue + } + namespace := strings.ToLower(strings.TrimSpace(column.Namespace)) + if namespace == "" { + namespace = rootName + } + if namespace == "" || !enabledAliases[namespace] { + continue + } + columnsConfig := result[namespace] + if columnsConfig == nil { + columnsConfig = map[string]*plan.ViewColumnConfig{} + result[namespace] = columnsConfig + } + if columnsConfig[name] == nil { + value := true + columnsConfig[name] = &plan.ViewColumnConfig{Groupable: &value} + } + } + if len(result) == 0 { + return nil + } + return result +} + +func extractRootGroupedColumnConfigs(sqlText, rootName string, enabledAliases map[string]bool) map[string]*plan.ViewColumnConfig { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" { + return nil + } + rootName = strings.ToLower(strings.TrimSpace(rootName)) + if len(enabledAliases) == 0 || !enabledAliases[rootName] { + return nil + } + queryNode, _, _, err := resolveQueryNode(sqlText) + if err != nil || queryNode == nil { + return nil + } + return lookupColumnConfigs(extractGroupableColumnsByNamespace(queryNode, rootName, enabledAliases), rootName) +} + +func extractExceptColumnsByNamespace(queryNode *query.Select) map[string]map[string]*plan.ViewColumnConfig { + if queryNode == nil { + return nil + } + result := map[string]map[string]*plan.ViewColumnConfig{} + for _, item := range queryNode.List { + if item == nil || item.Expr == nil { + continue + } + star, ok := item.Expr.(*expr.Star) + if !ok || len(star.Except) == 0 { + continue + } + selectorNs := "" + if selector, ok := star.X.(*expr.Selector); ok { + selectorNs = strings.ToLower(strings.TrimSpace(selector.Name)) + } + if selectorNs == "" { + continue + } + nsColumns := result[selectorNs] + if nsColumns == nil { + nsColumns = map[string]*plan.ViewColumnConfig{} + result[selectorNs] = nsColumns + } + for _, exceptColumn := range star.Except { + exceptColumn = strings.TrimSpace(exceptColumn) + if exceptColumn == "" { + continue + } + nsColumns[exceptColumn] = &plan.ViewColumnConfig{ + Tag: `internal:"true"`, + } + } + } + if len(result) == 0 { + return nil + } + return result +} + func extractSimpleFromTable(sqlText string) string { lower := strings.ToLower(sqlText) for i := 0; i+4 <= len(lower); i++ { @@ -195,6 +508,227 @@ func extractSimpleFromTable(sqlText string) string { return "" } +func extractRootSQLFromRaw(sqlText string) string { + if strings.TrimSpace(sqlText) == "" { + return "" + } + lower := strings.ToLower(sqlText) + depth := 0 + quote := byte(0) + for i := 0; i < len(sqlText); i++ { + ch := sqlText[i] + if quote != 0 { + if ch == '\\' && i+1 < len(sqlText) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + switch ch { + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + } + if depth != 0 || !hasReadWordAt(lower, i, "from") { + continue + } + start := skipReadSpaces(sqlText, i+4) + if start >= len(sqlText) { + return "" + } + fromExpr := trimJoinSuffix(sqlText[start:]) + fromExpr = strings.TrimSpace(fromExpr) + if fromExpr == "" { + return "" + } + if candidate := extractParenthesizedSelect(fromExpr); candidate != "" { + return candidate + } + fromExpr = unwrapReadParens(fromExpr) + fromExpr = strings.TrimSpace(fromExpr) + if fromExpr == "" { + return "" + } + if strings.HasPrefix(strings.ToLower(fromExpr), "select ") { + return fromExpr + } + return "SELECT * FROM " + fromExpr + } + return "" +} + +func unwrapReadParens(input string) string { + input = strings.TrimSpace(input) + if len(input) < 2 || input[0] != '(' || input[len(input)-1] != ')' { + return input + } + depth := 0 + quote := byte(0) + for i := 0; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + switch ch { + case '(': + depth++ + case ')': + depth-- + if depth == 0 && i != len(input)-1 { + return input + } + } + } + if depth != 0 { + return input + } + inner := strings.TrimSpace(input[1 : len(input)-1]) + if inner == "" { + return input + } + return inner +} + +func trimJoinSuffix(input string) string { + input = strings.TrimSpace(input) + if input == "" { + return "" + } + depth := 0 + quote := byte(0) + for i := 0; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + switch ch { + case '(': + depth++ + case ')': + if depth > 0 { + depth-- + } + } + if depth != 0 || !isReadWordStart(input[i]) { + continue + } + if hasReadWordAt(strings.ToLower(input), i, "join") { + return strings.TrimSpace(input[:i]) + } + } + return input +} + +func extractParenthesizedSelect(input string) string { + input = strings.TrimSpace(input) + if input == "" || input[0] != '(' { + return "" + } + body, end, ok := readReadParenBody(input, 0) + if !ok { + return "" + } + tail := strings.TrimSpace(input[end+1:]) + if tail != "" && !isReadIdentifierStart(tail[0]) { + return "" + } + body = strings.TrimSpace(body) + if strings.HasPrefix(strings.ToLower(body), "select ") { + return body + } + return "" +} + +func readReadParenBody(input string, openParen int) (string, int, bool) { + depth := 0 + quote := byte(0) + for i := openParen; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + depth-- + if depth == 0 { + return input[openParen+1 : i], i, true + } + } + } + return "", -1, false +} + +func hasReadWordAt(lower string, pos int, word string) bool { + if pos < 0 || pos+len(word) > len(lower) { + return false + } + if lower[pos:pos+len(word)] != word { + return false + } + if pos > 0 && isReadWordPart(lower[pos-1]) { + return false + } + next := pos + len(word) + if next < len(lower) && isReadWordPart(lower[next]) { + return false + } + return true +} + +func isReadWordStart(ch byte) bool { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' +} + +func isReadWordPart(ch byte) bool { + return isReadWordStart(ch) || (ch >= '0' && ch <= '9') +} + // collectDiags returns a single-element slice for a non-nil diagnostic, // or nil otherwise. Used to avoid repeated nil checks at call sites. func collectDiags(diag *dqlshape.Diagnostic) []*dqlshape.Diagnostic { diff --git a/repository/shape/compile/pipeline/read_normalize.go b/repository/shape/compile/pipeline/read_normalize.go index 6ff42af2f..826cdcc77 100644 --- a/repository/shape/compile/pipeline/read_normalize.go +++ b/repository/shape/compile/pipeline/read_normalize.go @@ -14,6 +14,11 @@ func normalizeParserSQL(sqlText string) string { return rewritePrivateShorthand(replaceTemplateTokens(sqlText)) } +// NormalizeParserSQL exposes the parser-safe SQL normalization used by read compilation. +func NormalizeParserSQL(sqlText string) string { + return normalizeParserSQL(sqlText) +} + func rewritePrivateShorthand(input string) string { var b strings.Builder b.Grow(len(input)) @@ -154,6 +159,9 @@ func normalizeTemplateExprBody(body string) (string, bool) { if isReadReservedName(trimmed) { return "", true } + if selector := normalizeTemplateSelector(trimmed); selector != "" { + return selector, false + } lower := strings.ToLower(trimmed) if strings.Contains(lower, `build("where")`) || strings.Contains(lower, "build('where')") { return " WHERE 1 ", false @@ -164,6 +172,38 @@ func normalizeTemplateExprBody(body string) (string, bool) { return "1", false } +func normalizeTemplateSelector(input string) string { + if input == "" { + return "" + } + for i := 0; i < len(input); i++ { + ch := input[i] + if !(isReadIdentifierPart(ch) || ch == '.') { + return "" + } + } + parts := strings.Split(input, ".") + builder := strings.Builder{} + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if builder.Len() > 0 { + builder.WriteByte('_') + } + builder.WriteString(part) + } + result := builder.String() + if result == "" { + return "" + } + if !isReadIdentifierStart(result[0]) { + return "" + } + return result +} + func readReadTemplateExpr(input string, openBrace int) (string, int, bool) { if openBrace <= 0 || openBrace >= len(input) || input[openBrace] != '{' || input[openBrace-1] != '$' { return "", -1, false diff --git a/repository/shape/compile/pipeline/read_test.go b/repository/shape/compile/pipeline/read_test.go index 9d414beb8..87fe8aec8 100644 --- a/repository/shape/compile/pipeline/read_test.go +++ b/repository/shape/compile/pipeline/read_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/plan" "github.com/viant/sqlparser/expr" "github.com/viant/sqlparser/query" ) @@ -31,10 +32,26 @@ JOIN (SELECT * FROM session/attributes) attribute ON attribute.user_id = session require.NotNil(t, view) assert.Equal(t, "session", view.Name) assert.Equal(t, "session", view.Table) + assert.Contains(t, view.SQL, "$criteria.AppendBinding($Unsafe.Jwt.UserID)") require.NotEmpty(t, view.Relations) assert.Equal(t, "attribute", view.Relations[0].Ref) } +func TestExtractRootSQLFromRaw_JoinRootTable(t *testing.T) { + sqlText := "SELECT o.id, i.sku FROM orders o JOIN items i ON o.id = i.order_id" + assert.Equal(t, "SELECT * FROM orders o", extractRootSQLFromRaw(sqlText)) +} + +func TestExtractRootSQLFromRaw_PreservesTemplateVariables(t *testing.T) { + sqlText := `SELECT wrapper.* EXCEPT ID, + vendor.* +FROM (SELECT ID FROM VENDOR WHERE ID = $vendorID ) wrapper +JOIN (SELECT * FROM VENDOR t WHERE t.ID = $vendorID ) vendor ON vendor.ID = wrapper.ID` + root := extractRootSQLFromRaw(sqlText) + assert.Contains(t, root, "$vendorID") + assert.NotContains(t, root, " ID = 1 ") +} + func TestNormalizeParserSQL(t *testing.T) { input := "SELECT * FROM session WHERE user_id = $criteria.AppendBinding($Unsafe.Jwt.UserID) AND x = $Jwt.UserID" actual := normalizeParserSQL(input) @@ -50,6 +67,13 @@ func TestNormalizeParserSQL_VeltyBlockExpression(t *testing.T) { assert.Contains(t, actual, "SELECT b.* FROM CI_BROWSER b WHERE 1 AND b.ARCHIVED = 0") } +func TestNormalizeParserSQL_TemplateSelector(t *testing.T) { + input := `SELECT * FROM ${Unsafe.Vendor} t WHERE t.ID = ${Unsafe.VendorID}` + actual := normalizeParserSQL(input) + assert.Contains(t, actual, "FROM Unsafe_Vendor t") + assert.Contains(t, actual, "t.ID = Unsafe_VendorID") +} + func TestNormalizeParserSQL_PrivateShorthand(t *testing.T) { input := `SELECT private(audience.FREQ_CAPPING) AS freq_capping FROM CI_AUDIENCE audience` actual := normalizeParserSQL(input) @@ -71,3 +95,199 @@ func TestBuildRead_FallbackWhenInitialParseFails(t *testing.T) { assert.Equal(t, "CI_BROWSER", view.Table) assert.Empty(t, diags) } + +func TestBuildRead_NoJoin_UsesFromSourceSQL(t *testing.T) { + sqlText := `SELECT user.* EXCEPT MGR_ID, self_ref(user, 'Team', 'ID', 'MGR_ID') FROM (SELECT t.* FROM USER t) user` + view, _, err := BuildRead("user_tree", sqlText) + require.NoError(t, err) + require.NotNil(t, view) + assert.Equal(t, "SELECT t.* FROM USER t", strings.TrimSpace(view.SQL)) +} + +func TestBuildRead_ExceptBecomesInternalColumnConfig(t *testing.T) { + sqlText := `SELECT user.* EXCEPT MGR_ID FROM (SELECT t.* FROM USER t) user` + view, _, err := BuildRead("user_tree", sqlText) + require.NoError(t, err) + require.NotNil(t, view) + require.NotNil(t, view.Declaration) + require.NotNil(t, view.Declaration.ColumnsConfig) + cfg, ok := view.Declaration.ColumnsConfig["MGR_ID"] + require.True(t, ok) + require.NotNil(t, cfg) + assert.Equal(t, `internal:"true"`, cfg.Tag) +} + +func TestBuildRead_ChildExceptBecomesRelationColumnConfig(t *testing.T) { + sqlText := `SELECT wrapper.* EXCEPT ID, + products.* EXCEPT VENDOR_ID, + setting.* EXCEPT ID +FROM (SELECT ID FROM VENDOR WHERE ID = $VendorID) wrapper +JOIN (SELECT * FROM (SELECT (1) AS IS_ACTIVE, (3) AS CHANNEL, CAST($VendorID AS SIGNED) AS ID) t) setting ON setting.ID = wrapper.ID +JOIN (SELECT * FROM PRODUCT t) products ON products.VENDOR_ID = wrapper.ID` + view, _, err := BuildRead("vendor_details", sqlText) + require.NoError(t, err) + require.NotNil(t, view) + require.Len(t, view.Relations, 2) + + var productsCfg, settingCfg map[string]*plan.ViewColumnConfig + for _, rel := range view.Relations { + switch rel.Ref { + case "products": + productsCfg = rel.ColumnsConfig + case "setting": + settingCfg = rel.ColumnsConfig + } + } + require.Contains(t, productsCfg, "VENDOR_ID") + assert.Equal(t, `internal:"true"`, productsCfg["VENDOR_ID"].Tag) + require.Contains(t, settingCfg, "ID") + assert.Equal(t, `internal:"true"`, settingCfg["ID"].Tag) +} + +func TestBuildRead_GroupByDoesNotMarkColumnsWithoutExplicitGrouping(t *testing.T) { + sqlText := `SELECT t.REGION AS REGION, COUNT(*) AS TOTAL FROM SALES t GROUP BY REGION` + view, _, err := BuildRead("sales_report", sqlText) + require.NoError(t, err) + require.NotNil(t, view) + if view.Declaration != nil { + assert.Empty(t, view.Declaration.ColumnsConfig) + } +} + +func TestBuildRead_GroupByMarksRootGroupedColumnsWithExplicitGrouping(t *testing.T) { + sqlText := `SELECT t.REGION AS REGION, COUNT(*) AS TOTAL FROM SALES t GROUP BY REGION` + view, _, err := BuildReadWithOptions("sales_report", sqlText, nil, map[string]bool{"t": true}) + require.NoError(t, err) + require.NotNil(t, view) + require.NotNil(t, view.Declaration) + require.NotNil(t, view.Declaration.ColumnsConfig) + cfg, ok := view.Declaration.ColumnsConfig["REGION"] + require.True(t, ok) + require.NotNil(t, cfg) + require.NotNil(t, cfg.Groupable) + assert.True(t, *cfg.Groupable) + _, ok = view.Declaration.ColumnsConfig["TOTAL"] + assert.False(t, ok) +} + +func TestBuildRead_GroupByMarksRelationGroupedColumnsWithExplicitGrouping(t *testing.T) { + sqlText := `SELECT vendor.REGION AS REGION, products.CATEGORY AS CATEGORY, COUNT(*) AS TOTAL +FROM VENDOR vendor +JOIN PRODUCT products ON products.VENDOR_ID = vendor.ID +GROUP BY vendor.REGION, products.CATEGORY` + view, _, err := BuildReadWithOptions("vendor_products", sqlText, nil, map[string]bool{"vendor": true, "products": true}) + require.NoError(t, err) + require.NotNil(t, view) + require.NotNil(t, view.Declaration) + require.Contains(t, view.Declaration.ColumnsConfig, "REGION") + require.NotNil(t, view.Declaration.ColumnsConfig["REGION"].Groupable) + assert.True(t, *view.Declaration.ColumnsConfig["REGION"].Groupable) + require.Len(t, view.Relations, 1) + require.Contains(t, view.Relations[0].ColumnsConfig, "CATEGORY") + require.NotNil(t, view.Relations[0].ColumnsConfig["CATEGORY"].Groupable) + assert.True(t, *view.Relations[0].ColumnsConfig["CATEGORY"].Groupable) + _, ok := view.Relations[0].ColumnsConfig["TOTAL"] + assert.False(t, ok) +} + +func TestBuildRead_GroupByInRootSubqueryMarksGroupedColumnsWithExplicitGrouping(t *testing.T) { + sqlText := `SELECT vendor.* +FROM ( + SELECT ACCOUNT_ID, + USER_CREATED, + SUM(ID) AS TOTAL_ID, + MAX(ID) AS MAX_ID + FROM VENDOR t + GROUP BY 1, 2 +) vendor` + view, _, err := BuildReadWithOptions("vendors_grouping", sqlText, nil, map[string]bool{"vendor": true}) + require.NoError(t, err) + require.NotNil(t, view) + require.NotNil(t, view.Declaration) + require.NotNil(t, view.Declaration.ColumnsConfig) + require.Contains(t, view.Declaration.ColumnsConfig, "ACCOUNT_ID") + require.NotNil(t, view.Declaration.ColumnsConfig["ACCOUNT_ID"].Groupable) + assert.True(t, *view.Declaration.ColumnsConfig["ACCOUNT_ID"].Groupable) + require.Contains(t, view.Declaration.ColumnsConfig, "USER_CREATED") + require.NotNil(t, view.Declaration.ColumnsConfig["USER_CREATED"].Groupable) + assert.True(t, *view.Declaration.ColumnsConfig["USER_CREATED"].Groupable) + _, ok := view.Declaration.ColumnsConfig["TOTAL_ID"] + assert.False(t, ok) +} + +func TestBuildRead_GroupByWithQualifiedColumnsAndTemplatePredicateMarksPublisherID(t *testing.T) { + sqlText := `SELECT + p.event_date, + p.agency_id, + p.advertiser_id, + p.campaign_id, + p.ad_order_id, + p.audience_id, + p.deal_id, + p.publisher_id, + p.channel_id, + p.country, + p.site_type, + SUM(p.bids) AS bids, + SUM(p.impressions) AS impressions, + SUM(p.clicks) AS clicks, + SUM(p.conversions) AS conversions, + SUM(p.total_spend) AS total_spend +FROM + ` + "`viant-mediator.forecaster.fact_perf_daily_mv`" + ` p +WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL $DateInterval DAY) +AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) + ${predicate.Builder().CombineOr($predicate.FilterGroup(0, "AND")).Build("AND")} +GROUP BY + p.event_date, + p.agency_id, + p.advertiser_id, + p.campaign_id, + p.ad_order_id, + p.audience_id, + p.deal_id, + p.publisher_id, + p.channel_id, + p.country, + p.site_type` + view, diags, err := BuildReadWithOptions("fact_perf_daily_mv", sqlText, nil, map[string]bool{"p": true}) + require.NoError(t, err) + require.NotNil(t, view) + require.Empty(t, diags) + require.NotNil(t, view.Declaration) + require.NotNil(t, view.Declaration.ColumnsConfig) + cfg, ok := view.Declaration.ColumnsConfig["publisher_id"] + require.True(t, ok) + require.NotNil(t, cfg) + require.NotNil(t, cfg.Groupable) + assert.True(t, *cfg.Groupable) + _, ok = view.Declaration.ColumnsConfig["total_spend"] + assert.False(t, ok) +} + +func TestBuildRead_TemplateTableSelector_PreservesRelations(t *testing.T) { + sqlText := `SELECT vendor.*, products.* +FROM (SELECT * FROM ${Unsafe.Vendor} t WHERE t.ID IN ($criteria.AppendBinding($Unsafe.vendorIDs))) vendor +JOIN (SELECT * FROM ${Unsafe.Product} t) products ON products.VENDOR_ID = vendor.ID` + view, _, err := BuildRead("const", sqlText) + require.NoError(t, err) + require.NotNil(t, view) + assert.Equal(t, "vendor", view.Name) + require.NotEmpty(t, view.Relations) + assert.Equal(t, "products", view.Relations[0].Ref) +} + +func TestBuildReadWithConsts_ResolvesUnsafeTablePlaceholders(t *testing.T) { + sqlText := `SELECT vendor.*, products.* +FROM (SELECT * FROM ${Unsafe.Vendor} t WHERE t.ID IN ($criteria.AppendBinding($Unsafe.vendorIDs))) vendor +JOIN (SELECT * FROM ${Unsafe.Product} t) products ON products.VENDOR_ID = vendor.ID` + view, _, err := BuildReadWithConsts("const", sqlText, map[string]string{ + "Vendor": "VENDOR", + "Product": "PRODUCT", + }) + require.NoError(t, err) + require.NotNil(t, view) + assert.Equal(t, "VENDOR", view.Table) + require.NotEmpty(t, view.Relations) + assert.Contains(t, view.Relations[0].Table, "PRODUCT") +} diff --git a/repository/shape/compile/pipeline/relation.go b/repository/shape/compile/pipeline/relation.go index dc94b8955..d60b56816 100644 --- a/repository/shape/compile/pipeline/relation.go +++ b/repository/shape/compile/pipeline/relation.go @@ -30,6 +30,7 @@ func ExtractJoinRelations(raw string, queryNode *query.Select) ([]*plan.Relation ref, table := relationRef(join, idx+1) relation := &plan.Relation{ Name: ref, + Parent: relationParentAlias(join, rootAlias), Holder: ExportedName(ref), Ref: ref, Table: table, @@ -346,6 +347,31 @@ func rootNamespace(queryNode *query.Select) string { return root } +func relationParentAlias(join *query.Join, rootAlias string) string { + if join == nil || join.On == nil { + return strings.TrimSpace(rootAlias) + } + parent := "" + sqlparser.Traverse(join.On, func(n node.Node) bool { + selector, ok := n.(*expr.Selector) + if !ok { + return true + } + name := strings.TrimSpace(selector.Name) + if name == "" || strings.EqualFold(name, strings.TrimSpace(join.Alias)) { + return true + } + if parent == "" { + parent = name + } + return true + }) + if parent != "" { + return parent + } + return strings.TrimSpace(rootAlias) +} + func relationRef(join *query.Join, ordinal int) (string, string) { if join == nil { return fmt.Sprintf("join_%d", ordinal), "" diff --git a/repository/shape/compile/pipeline/relation_test.go b/repository/shape/compile/pipeline/relation_test.go index 62f5c9d3a..dd358aadb 100644 --- a/repository/shape/compile/pipeline/relation_test.go +++ b/repository/shape/compile/pipeline/relation_test.go @@ -49,6 +49,9 @@ func TestExtractJoinRelations_NonRootParentChain(t *testing.T) { require.NoError(t, err) relations, diags := ExtractJoinRelations(sqlText, queryNode) require.Len(t, relations, 3) + assert.Equal(t, "sl", relations[0].Parent) + assert.Equal(t, "m", relations[1].Parent) + assert.Equal(t, "s", relations[2].Parent) require.Len(t, relations[0].On, 1) assert.Equal(t, "sl", relations[0].On[0].ParentNamespace) @@ -70,6 +73,56 @@ func TestExtractJoinRelations_NonRootParentChain(t *testing.T) { assert.Empty(t, diags) } +func TestExtractJoinRelations_ParentAliasMatrix(t *testing.T) { + testCases := []struct { + name string + sqlText string + expected map[string]string + }{ + { + name: "root parent", + sqlText: "SELECT o.id FROM orders o JOIN order_items i ON o.id = i.order_id", + expected: map[string]string{ + "i": "o", + }, + }, + { + name: "multi level chain", + sqlText: "SELECT sl.id FROM site_list sl JOIN site_list_match m ON m.site_list_id = sl.id JOIN ci_site s ON s.id = m.site_id JOIN ci_publisher p ON p.id = s.publisher_id", + expected: map[string]string{ + "m": "sl", + "s": "m", + "p": "s", + }, + }, + { + name: "left join child of child", + sqlText: "SELECT a.id FROM alpha a LEFT JOIN beta b ON b.a_id = a.id LEFT JOIN gamma g ON g.b_id = b.id", + expected: map[string]string{ + "b": "a", + "g": "b", + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + queryNode, err := sqlparser.ParseQuery(testCase.sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(testCase.sqlText, queryNode) + assert.Empty(t, diags) + got := map[string]string{} + for _, relation := range relations { + if relation == nil { + continue + } + got[relation.Ref] = relation.Parent + } + assert.Equal(t, testCase.expected, got) + }) + } +} + func TestExtractJoinRelations_DoesNotFallbackForComplexRawPredicate(t *testing.T) { sqlText := "SELECT o.id FROM orders o JOIN order_items i ON COALESCE(o.id, 0) = i.order_id" queryNode, err := sqlparser.ParseQuery(sqlText) diff --git a/repository/shape/compile/resolver_test.go b/repository/shape/compile/resolver_test.go new file mode 100644 index 000000000..ab88a3fe8 --- /dev/null +++ b/repository/shape/compile/resolver_test.go @@ -0,0 +1,17 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSplitRouteKey(t *testing.T) { + method, uri := splitRouteKey("POST:/v1/api/platform/acl/auth") + assert.Equal(t, "POST", method) + assert.Equal(t, "/v1/api/platform/acl/auth", uri) + + method, uri = splitRouteKey("/v1/api/platform/acl/auth") + assert.Equal(t, "GET", method) + assert.Equal(t, "/v1/api/platform/acl/auth", uri) +} diff --git a/repository/shape/compile/route_index_test.go b/repository/shape/compile/route_index_test.go new file mode 100644 index 000000000..c93454875 --- /dev/null +++ b/repository/shape/compile/route_index_test.go @@ -0,0 +1,71 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildRouteIndex_AndResolve(t *testing.T) { + tempDir := t.TempDir() + authPath := filepath.Join(tempDir, "dql", "platform", "acl", "auth.dql") + reportPath := filepath.Join(tempDir, "dql", "platform", "reports", "orders", "orders.dql") + require.NoError(t, writeFile(authPath, `/* {"URI":"/v1/api/platform/acl/auth","Method":"GET"} */ SELECT 1`)) + require.NoError(t, writeFile(reportPath, `SELECT 1`)) + + index, err := BuildRouteIndex([]string{authPath, reportPath}) + require.NoError(t, err) + + _, ok := index.ByRouteKey["GET:/v1/api/platform/acl/auth"] + assert.True(t, ok) + _, ok = index.ByRouteKey["GET:/v1/api/platform/reports/orders"] + assert.True(t, ok) // inferred from namespace when URI is not explicitly declared + + resolved, ok := index.Resolve("../../acl/auth", reportPath) + require.True(t, ok) + assert.Equal(t, "GET:/v1/api/platform/acl/auth", resolved) +} + +func TestRouteIndex_ResolveByAbsoluteURI(t *testing.T) { + tempDir := t.TempDir() + authPath := filepath.Join(tempDir, "dql", "platform", "acl", "auth.dql") + require.NoError(t, writeFile(authPath, `/* {"URI":"/v1/api/platform/acl/auth","Method":"POST"} */ SELECT 1`)) + + index, err := BuildRouteIndex([]string{authPath}) + require.NoError(t, err) + + resolved, ok := index.Resolve("POST:/v1/api/platform/acl/auth", authPath) + require.True(t, ok) + assert.Equal(t, "POST:/v1/api/platform/acl/auth", resolved) +} + +func TestBuildRouteIndex_Conflicts(t *testing.T) { + tempDir := t.TempDir() + leftPath := filepath.Join(tempDir, "dql", "platform", "left", "x.dql") + rightPath := filepath.Join(tempDir, "dql", "platform", "right", "y.dql") + content := `/* {"URI":"/v1/api/platform/shared/resource","Method":"GET"} */ SELECT 1` + require.NoError(t, writeFile(leftPath, content)) + require.NoError(t, writeFile(rightPath, content)) + + index, err := BuildRouteIndex([]string{leftPath, rightPath}) + require.NoError(t, err) + + conflicts := index.Conflicts["GET:/v1/api/platform/shared/resource"] + require.Len(t, conflicts, 2) + _, ok := index.Resolve("GET:/v1/api/platform/shared/resource", leftPath) + assert.False(t, ok) +} + +func writeFile(path, content string) error { + if err := ensureDir(filepath.Dir(path)); err != nil { + return err + } + return os.WriteFile(path, []byte(content), 0o644) +} + +func ensureDir(path string) error { + return os.MkdirAll(path, 0o755) +} diff --git a/repository/shape/compile/statedecl.go b/repository/shape/compile/statedecl.go index bd401b11d..7e1a5bde8 100644 --- a/repository/shape/compile/statedecl.go +++ b/repository/shape/compile/statedecl.go @@ -1,9 +1,13 @@ package compile import ( + "fmt" "strconv" "strings" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" "github.com/viant/datly/view/extension" st "github.com/viant/datly/view/state" @@ -14,25 +18,28 @@ func appendDeclaredStates(rawDQL string, result *plan.Result) { if result == nil || strings.TrimSpace(rawDQL) == "" { return } - seen := map[string]bool{} + seen := map[string]*plan.State{} for _, block := range extractSetBlocks(rawDQL) { - holder, kind, location, tail, ok := parseSetDeclarationBody(block.Body) + holder, kind, location, tail, tailOffset, ok := parseSetDeclarationBody(block.Body) if !ok { continue } - if kind == "view" || kind == "data_view" { - continue - } key := declaredStateKey(holder, kind, location) - if seen[key] { - continue + inName := location + if kind == "view" || kind == "data_view" { + if isAttachedSummaryState(result, holder) { + continue + } + // Keep parity with legacy translator: view declarations are addressed + // by declaration holder name (e.g. $Authorization(view/authorization)). + inName = holder } state := &plan.State{ Parameter: st.Parameter{ Name: holder, In: &st.Location{ Kind: st.Kind(kind), - Name: location, + Name: inName, }, }, } @@ -48,10 +55,124 @@ func appendDeclaredStates(rawDQL string, result *plan.Result) { required := true state.Required = &required } - applyDeclaredStateOptions(state, tail) + applyDeclaredStateOptions(state, tail, rawDQL, block.BodyOffset+tailOffset, &result.Diagnostics) + if existing := seen[key]; existing != nil { + mergeDeclaredState(existing, state) + continue + } result.States = append(result.States, state) + seen[key] = state + } + appendInferredPathStates(rawDQL, result, seen) +} + +func appendInferredPathStates(rawDQL string, result *plan.Result, seen map[string]*plan.State) { + if result == nil || strings.TrimSpace(rawDQL) == "" { + return + } + prepared := dqlpre.Prepare(rawDQL) + if prepared.Directives == nil || prepared.Directives.Route == nil { + return + } + for _, name := range extractRoutePathParams(prepared.Directives.Route.URI) { + key := declaredStateKey(name, string(st.KindPath), name) + if seen[key] != nil { + continue + } + result.States = append(result.States, &plan.State{ + Parameter: st.Parameter{ + Name: name, + In: st.NewPathLocation(name), + Schema: &st.Schema{ + DataType: "string", + Cardinality: st.One, + }, + }, + }) + seen[key] = result.States[len(result.States)-1] + } +} + +func mergeDeclaredState(dst, src *plan.State) { + if dst == nil || src == nil { + return + } + dst.EmitOutput = dst.EmitOutput || src.EmitOutput + dst.Async = dst.Async || src.Async + if dst.QuerySelector == "" { + dst.QuerySelector = src.QuerySelector + } + if dst.OutputDataType == "" { + dst.OutputDataType = src.OutputDataType + } + if dst.Tag == "" { + dst.Tag = src.Tag + } + if dst.Required == nil { + dst.Required = src.Required + } + if dst.Cacheable == nil { + dst.Cacheable = src.Cacheable + } + if dst.Schema == nil && src.Schema != nil { + schema := *src.Schema + dst.Schema = &schema + } + if dst.Schema != nil && src.Schema != nil { + if dst.Schema.DataType == "" { + dst.Schema.DataType = src.Schema.DataType + } + if dst.Schema.Cardinality == "" { + dst.Schema.Cardinality = src.Schema.Cardinality + } + } +} + +func extractRoutePathParams(uri string) []string { + uri = strings.TrimSpace(uri) + if uri == "" { + return nil + } + var result []string + seen := map[string]bool{} + for { + start := strings.IndexByte(uri, '{') + if start == -1 { + break + } + uri = uri[start+1:] + end := strings.IndexByte(uri, '}') + if end == -1 { + break + } + name := strings.TrimSpace(uri[:end]) + uri = uri[end+1:] + if name == "" { + continue + } + key := strings.ToLower(name) + if seen[key] { + continue + } seen[key] = true + result = append(result, name) } + return result +} + +func isAttachedSummaryState(result *plan.Result, holder string) bool { + if result == nil || strings.TrimSpace(holder) == "" { + return false + } + for _, item := range result.Views { + if item == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(item.SummaryName), strings.TrimSpace(holder)) { + return true + } + } + return false } func declaredStateKey(name, kind, in string) string { @@ -60,77 +181,161 @@ func declaredStateKey(name, kind, in string) string { strings.ToLower(strings.TrimSpace(in)) } -func applyDeclaredStateOptions(state *plan.State, tail string) { +func applyDeclaredStateOptions(state *plan.State, tail, dql string, baseOffset int, diags *[]*dqlshape.Diagnostic) { if state == nil || strings.TrimSpace(tail) == "" { return } cursor := newOptionCursor(tail) for cursor.next() { name, args := cursor.option() + optionOffset := baseOffset + cursor.start switch { case strings.EqualFold(name, "WithURI"): - if len(args) == 1 { - state.URI = trimQuote(args[0]) + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + state.URI = trimQuote(args[0]) + case strings.EqualFold(name, "WithTag"), strings.EqualFold(name, "Tag"): + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue } + state.Tag = trimQuote(args[0]) case strings.EqualFold(name, "Optional"): + if !expectStateArgs(state, name, args, 0, 0, dql, optionOffset, diags) { + continue + } required := false state.Required = &required case strings.EqualFold(name, "Required"): + if !expectStateArgs(state, name, args, 0, 0, dql, optionOffset, diags) { + continue + } required := true state.Required = &required case strings.EqualFold(name, "Cacheable"): - if len(args) == 1 { - if value, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[0]))); err == nil { - state.Cacheable = &value - } + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + value, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + appendStateOptionDiagnostic(state, name, fmt.Sprintf("invalid bool cacheable %q", args[0]), dql, optionOffset, diags) + continue } + state.Cacheable = &value case strings.EqualFold(name, "QuerySelector"): - if len(args) == 1 { - state.QuerySelector = trimQuote(args[0]) - if state.Cacheable == nil { - cacheable := false - state.Cacheable = &cacheable - } + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + state.QuerySelector = trimQuote(args[0]) + if state.Cacheable == nil { + cacheable := false + state.Cacheable = &cacheable } case strings.EqualFold(name, "WithPredicate"), strings.EqualFold(name, "Predicate"): + if !expectStateArgs(state, name, args, 1, -1, dql, optionOffset, diags) { + continue + } appendStatePredicate(state, args, false) case strings.EqualFold(name, "EnsurePredicate"): + if !expectStateArgs(state, name, args, 1, -1, dql, optionOffset, diags) { + continue + } appendStatePredicate(state, args, true) case strings.EqualFold(name, "When"): - if len(args) == 1 { - state.When = trimQuote(args[0]) + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue } + state.When = trimQuote(args[0]) case strings.EqualFold(name, "Scope"): - if len(args) == 1 { - state.Scope = trimQuote(args[0]) + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue } + state.Scope = trimQuote(args[0]) case strings.EqualFold(name, "WithType"): - if len(args) == 1 { - ensureStateSchema(state).DataType = trimQuote(args[0]) + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue } + ensureStateSchema(state).DataType = trimQuote(args[0]) case strings.EqualFold(name, "WithCodec"): - if len(args) >= 1 { - state.Output = &st.Codec{ - Name: trimQuote(args[0]), - Args: append([]string{}, trimQuotedArgs(args[1:])...), - } + if !expectStateArgs(state, name, args, 1, -1, dql, optionOffset, diags) { + continue + } + state.Output = &st.Codec{ + Name: trimQuote(args[0]), + Args: append([]string{}, trimQuotedArgs(args[1:])...), } case strings.EqualFold(name, "WithStatusCode"): - if len(args) == 1 { - if value, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))); err == nil { - state.ErrorStatusCode = value - } + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + value, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + appendStateOptionDiagnostic(state, name, fmt.Sprintf("invalid status code %q", args[0]), dql, optionOffset, diags) + continue } + state.ErrorStatusCode = value case strings.EqualFold(name, "WithErrorMessage"): - if len(args) == 1 { - state.ErrorMessage = trimQuote(args[0]) + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue } + state.ErrorMessage = trimQuote(args[0]) case strings.EqualFold(name, "Value"): - if len(args) == 1 { - state.Value = trimQuote(args[0]) + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + state.Value = trimQuote(args[0]) + case strings.EqualFold(name, "Embed"): + if !expectStateArgs(state, name, args, 0, 0, dql, optionOffset, diags) { + continue + } + if !strings.Contains(state.Tag, `anonymous:"true"`) { + if strings.TrimSpace(state.Tag) != "" { + state.Tag += " " + } + state.Tag += `anonymous:"true"` + } + case strings.EqualFold(name, "Cardinality"): + if !expectStateArgs(state, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + card := strings.ToLower(strings.TrimSpace(trimQuote(args[0]))) + switch card { + case "one": + ensureStateSchema(state).Cardinality = st.One + case "many": + ensureStateSchema(state).Cardinality = st.Many + default: + if state != nil && state.In != nil { + kind := strings.ToLower(state.KindString()) + if kind == "view" || kind == "data_view" { + // Declared views already validate cardinality with DQL-VIEW-CARDINALITY. + // Avoid duplicating that diagnostic on the shadow state projection. + continue + } + } + appendStateOptionDiagnostic(state, name, fmt.Sprintf("unsupported cardinality %q", args[0]), dql, optionOffset, diags) } case strings.EqualFold(name, "Async"): + if !expectStateArgs(state, name, args, 0, 0, dql, optionOffset, diags) { + continue + } state.Async = true + case strings.EqualFold(name, "Output"): + if !expectStateArgs(state, name, args, 0, 0, dql, optionOffset, diags) { + continue + } + state.EmitOutput = true + default: + if state != nil && state.In != nil { + kind := strings.ToLower(state.KindString()) + if kind == "view" || kind == "data_view" { + // View declarations carry many view-level options (e.g. Cardinality, + // WithURI, WithColumnType). Those are handled by declared-view parsing + // and should not emit state-option diagnostics. + continue + } + } + appendStateOptionDiagnostic(state, name, "unknown option", dql, optionOffset, diags) } } } @@ -222,6 +427,7 @@ func ensureStateSchema(state *plan.State) *st.Schema { type optionCursor struct { raw string cursor int + start int name string args []string } @@ -233,12 +439,14 @@ func newOptionCursor(raw string) *optionCursor { func (o *optionCursor) next() bool { o.name = "" o.args = nil + o.start = 0 for o.cursor < len(o.raw) && (o.raw[o.cursor] == ' ' || o.raw[o.cursor] == '\n' || o.raw[o.cursor] == '\t' || o.raw[o.cursor] == '\r') { o.cursor++ } if o.cursor >= len(o.raw) || o.raw[o.cursor] != '.' { return false } + o.start = o.cursor o.cursor++ start := o.cursor for o.cursor < len(o.raw) { @@ -305,3 +513,29 @@ func (o *optionCursor) next() bool { func (o *optionCursor) option() (string, []string) { return o.name, o.args } + +func expectStateArgs(state *plan.State, option string, args []string, min, max int, dql string, offset int, diags *[]*dqlshape.Diagnostic) bool { + if len(args) < min { + appendStateOptionDiagnostic(state, option, fmt.Sprintf("expected at least %d args, got %d", min, len(args)), dql, offset, diags) + return false + } + if max >= 0 && len(args) > max { + appendStateOptionDiagnostic(state, option, fmt.Sprintf("expected at most %d args, got %d", max, len(args)), dql, offset, diags) + return false + } + return true +} + +func appendStateOptionDiagnostic(state *plan.State, option, detail, dql string, offset int, diags *[]*dqlshape.Diagnostic) { + stateName := "" + if state != nil { + stateName = state.Name + } + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDeclOptionArgs, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("invalid %s declaration for state %q: %s", option, stateName, detail), + Hint: "check option name, arity and argument formatting", + Span: relationSpan(dql, offset), + }) +} diff --git a/repository/shape/compile/statedecl_test.go b/repository/shape/compile/statedecl_test.go index 33c9241a4..64d21ecb7 100644 --- a/repository/shape/compile/statedecl_test.go +++ b/repository/shape/compile/statedecl_test.go @@ -1,17 +1,20 @@ package compile import ( + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" "github.com/viant/datly/repository/shape/plan" ) func TestAppendDeclaredStates(t *testing.T) { dql := ` #set($_ = $Jwt(header/Authorization).WithCodec(JwtClaim).WithStatusCode(401)) -#set($_ = $Claims(header/Authorization).WithCodec(JwtClaim)) +#set($_ = $Claims(header/Authorization).WithCodec(JwtClaim).WithTag('json:"claims,omitempty"')) #set($_ = $Name(query/name).WithPredicate(0,'contains','sl','NAME').Optional()) #set($_ = $Fields<[]string>(query/fields).QuerySelector(site_list)) #set($_ = $Meta(output/summary)) @@ -37,6 +40,7 @@ SELECT id FROM SITE_LIST sl` require.NotNil(t, byName["Claims"]) assert.Equal(t, "string", byName["Claims"].Schema.DataType) assert.Equal(t, "*JwtClaims", byName["Claims"].OutputDataType) + assert.Equal(t, `json:"claims,omitempty"`, byName["Claims"].Tag) require.NotNil(t, byName["Name"]) assert.Equal(t, "query", byName["Name"].KindString()) @@ -77,3 +81,152 @@ SELECT id FROM USERS u` require.NotNil(t, result.States[0].Required) assert.True(t, *result.States[0].Required) } + +func TestAppendDeclaredStates_ViewDeclarationBecomesViewInput(t *testing.T) { + dql := ` +#define($_ = $Authorization(view/authorization).Required().WithStatusCode(403) /* SELECT Authorized FROM AUTH */) +SELECT id FROM USERS u` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.Len(t, result.States, 1) + state := result.States[0] + require.NotNil(t, state) + assert.Equal(t, "Authorization", state.Name) + assert.Equal(t, "view", state.KindString()) + assert.Equal(t, "Authorization", state.In.Name) + require.NotNil(t, state.Required) + assert.True(t, *state.Required) + assert.Equal(t, 403, state.ErrorStatusCode) +} + +func TestAppendDeclaredStates_SkipsSummaryAttachedViewDeclaration(t *testing.T) { + dql := ` +#define($_ = $ProductsMeta(view/products_meta) /* SELECT COUNT(1) CNT FROM ($View.products.SQL) t */) +SELECT vendor.*, products.* +FROM (SELECT * FROM VENDOR t) vendor +JOIN (SELECT * FROM PRODUCT t) products ON products.VENDOR_ID = vendor.ID` + result := &plan.Result{ + Views: []*plan.View{ + {Name: "vendor"}, + {Name: "products", SummaryName: "ProductsMeta"}, + }, + } + + appendDeclaredStates(dql, result) + + require.Empty(t, result.States) +} + +func TestAppendDeclaredStates_EmbedSetsAnonymousTag(t *testing.T) { + dql := ` +#set($_ = $Data(output/view).Embed()) +SELECT id FROM USERS u` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.Len(t, result.States, 1) + assert.Equal(t, "Data", result.States[0].Name) + assert.Equal(t, "output", result.States[0].KindString()) + assert.Contains(t, result.States[0].Tag, `anonymous:"true"`) +} + +func TestAppendDeclaredStates_OutputViewCardinalityIsParsed(t *testing.T) { + dql := ` +#define($_ = $Data(output/view).Cardinality('One').Embed()) +SELECT id FROM USERS u` + result := &plan.Result{} + + appendDeclaredStates(dql, result) + + require.Len(t, result.States, 1) + require.NotNil(t, result.States[0].Schema) + assert.Equal(t, "output", result.States[0].KindString()) + assert.Equal(t, "One", string(result.States[0].Schema.Cardinality)) + assert.Contains(t, result.States[0].Tag, `anonymous:"true"`) +} + +func TestAppendDeclaredStates_OutputOptionMarksStateForOutput(t *testing.T) { + dql := ` +#set($_ = $Foos(body/).Output().Tag('anonymous:"true"')) +SELECT * FROM FOOS` + result := &plan.Result{} + + appendDeclaredStates(dql, result) + + require.Len(t, result.States, 1) + assert.Equal(t, "Foos", result.States[0].Name) + assert.Equal(t, "body", result.States[0].KindString()) + assert.True(t, result.States[0].EmitOutput) +} + +func TestAppendDeclaredStates_DuplicateDeclarationMergesOutputMarker(t *testing.T) { + dql := ` +#set($_ = $Foos(body/).Cardinality('One').Tag('anonymous:"true"')) +#set($_ = $Foos(body/).Output().Tag('anonymous:"true"')) +SELECT * FROM FOOS` + result := &plan.Result{} + + appendDeclaredStates(dql, result) + + require.Len(t, result.States, 1) + assert.Equal(t, "Foos", result.States[0].Name) + assert.Equal(t, "body", result.States[0].KindString()) + assert.True(t, result.States[0].EmitOutput) + require.NotNil(t, result.States[0].Schema) + assert.Equal(t, "One", string(result.States[0].Schema.Cardinality)) +} + +func TestAppendDeclaredStates_InvalidOption_ReportsExactSpan(t *testing.T) { + dql := ` +#set($_ = $Auth(header/Authorization).Cacheable('x').UnknownFlag()) +SELECT id FROM USERS u` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.NotEmpty(t, result.Diagnostics) + require.GreaterOrEqual(t, len(result.Diagnostics), 2) + assert.Equal(t, dqldiag.CodeDeclOptionArgs, result.Diagnostics[0].Code) + assert.Equal(t, dqldiag.CodeDeclOptionArgs, result.Diagnostics[1].Code) + + cacheableOffset := strings.Index(dql, ".Cacheable") + require.GreaterOrEqual(t, cacheableOffset, 0) + cacheablePos := dqlpre.PointSpan(dql, cacheableOffset).Start + assert.Equal(t, cacheablePos.Line, result.Diagnostics[0].Span.Start.Line) + assert.Equal(t, cacheablePos.Char, result.Diagnostics[0].Span.Start.Char) + + unknownOffset := strings.Index(dql, ".UnknownFlag") + require.GreaterOrEqual(t, unknownOffset, 0) + unknownPos := dqlpre.PointSpan(dql, unknownOffset).Start + assert.Equal(t, unknownPos.Line, result.Diagnostics[1].Span.Start.Line) + assert.Equal(t, unknownPos.Char, result.Diagnostics[1].Span.Start.Char) +} + +func TestAppendDeclaredStates_InferPathStatesFromRouteDirective(t *testing.T) { + dql := ` +#setting($_ = $route('/v1/api/shape/dev/team/{teamID}', 'DELETE')) +DELETE FROM TEAM WHERE ID = ${teamID}` + result := &plan.Result{} + + appendDeclaredStates(dql, result) + + require.Len(t, result.States, 1) + assert.Equal(t, "teamID", result.States[0].Name) + assert.Equal(t, "path", result.States[0].KindString()) + assert.Equal(t, "teamID", result.States[0].In.Name) + require.NotNil(t, result.States[0].Schema) + assert.Equal(t, "string", result.States[0].Schema.DataType) +} + +func TestAppendDeclaredStates_ExplicitPathStateWinsOverInferredRouteParam(t *testing.T) { + dql := ` +#setting($_ = $route('/v1/api/shape/dev/vendors/{vendorID}', 'GET')) +#define($_ = $VendorID(path/vendorID)) +SELECT * FROM VENDOR WHERE ID = $VendorID` + result := &plan.Result{} + + appendDeclaredStates(dql, result) + + require.Len(t, result.States, 1) + assert.Equal(t, "VendorID", result.States[0].Name) + assert.Equal(t, "path", result.States[0].KindString()) + assert.Equal(t, "vendorID", result.States[0].In.Name) + assert.Equal(t, "int", result.States[0].Schema.DataType) +} diff --git a/repository/shape/compile/type_support.go b/repository/shape/compile/type_support.go index 44a8b9bca..585ae1e00 100644 --- a/repository/shape/compile/type_support.go +++ b/repository/shape/compile/type_support.go @@ -1,13 +1,22 @@ package compile import ( + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" "reflect" + "strconv" "strings" + "time" "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" "github.com/viant/datly/repository/shape/plan" "github.com/viant/datly/repository/shape/typectx" "github.com/viant/x" + "github.com/viant/xunsafe" ) func applyLinkedTypeSupport(result *plan.Result, source *shape.Source) { @@ -15,9 +24,6 @@ func applyLinkedTypeSupport(result *plan.Result, source *shape.Source) { return } registry := source.EnsureTypeRegistry() - if registry == nil || len(registry.Keys()) == 0 { - return - } resolver := typectx.NewResolver(registry, result.TypeContext) rootTypeKey := resolveRootTypeKey(source, resolver, registry) existing := existingTypesByName(result.Types) @@ -26,11 +32,7 @@ func applyLinkedTypeSupport(result *plan.Result, source *shape.Source) { if item == nil { continue } - resolvedKey := resolveViewTypeKey(item, idx == 0, rootTypeKey, resolver, registry) - if resolvedKey == "" { - continue - } - resolvedType := registry.Lookup(resolvedKey) + resolvedType := resolveViewType(item, idx == 0, rootTypeKey, resolver, registry, result.TypeContext, source) if resolvedType == nil || resolvedType.Type == nil { continue } @@ -38,6 +40,15 @@ func applyLinkedTypeSupport(result *plan.Result, source *shape.Source) { if rType == nil { continue } + if isPlaceholderLinkedViewType(rType) { + continue + } + item.ElementType = rType + if strings.EqualFold(strings.TrimSpace(item.Cardinality), "many") { + item.FieldType = reflect.SliceOf(rType) + } else { + item.FieldType = rType + } typeExpr, typePkg := schemaTypeExpression(rType, result.TypeContext) if shouldSetSchemaType(item) && typeExpr != "" { item.SchemaType = typeExpr @@ -61,6 +72,273 @@ func applyLinkedTypeSupport(result *plan.Result, source *shape.Source) { } } +func isPlaceholderLinkedViewType(rType reflect.Type) bool { + rType = unwrapResolvedType(rType) + if rType == nil || rType.Kind() != reflect.Struct { + return false + } + hasScalars := false + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if !field.IsExported() { + continue + } + rawTag := string(field.Tag) + if strings.Contains(rawTag, `view:"`) || strings.Contains(rawTag, `on:"`) || strings.Contains(rawTag, `sqlx:"-"`) { + continue + } + hasScalars = true + if !isPlaceholderFieldName(field.Name, summaryTagName(field.Tag.Get("sqlx"))) { + return false + } + } + return hasScalars +} + +func isPlaceholderFieldName(fieldName, sqlxName string) bool { + return isPlaceholderName(fieldName) || isPlaceholderName(sqlxName) +} + +func isPlaceholderName(name string) bool { + name = strings.TrimSpace(name) + if name == "" { + return false + } + name = strings.TrimPrefix(strings.TrimPrefix(name, "name="), "*") + lower := strings.ToLower(strings.ReplaceAll(name, "_", "")) + if !strings.HasPrefix(lower, "col") || len(lower) == len("col") { + return false + } + _, err := strconv.Atoi(lower[len("col"):]) + return err == nil +} + +func applySummaryTypeSupport(result *plan.Result, source *shape.Source) { + if result == nil || source == nil { + return + } + registry := source.EnsureTypeRegistry() + if registry == nil { + return + } + resolver := typectx.NewResolver(registry, result.TypeContext) + existing := existingTypesByName(result.Types) + applySummaryTypeSupportWithResolver(result, source, resolver, registry, existing) +} + +func applySummaryTypeSupportWithResolver(result *plan.Result, source *shape.Source, resolver *typectx.Resolver, registry *x.Registry, existing map[string]bool) { + if result == nil || source == nil || registry == nil { + return + } + for _, item := range result.Views { + if item == nil { + continue + } + summaryName := strings.TrimSpace(item.SummaryName) + summarySQL := strings.TrimSpace(item.Summary) + if summaryName == "" || summarySQL == "" { + continue + } + typeName := summaryTypeName(summaryName) + if typeName == "" { + continue + } + queryNode, _, err := pipeline.ParseSelectWithDiagnostic(pipeline.NormalizeParserSQL(summarySQL)) + if err == nil && queryNode != nil { + _, elementType, _ := pipeline.InferProjectionType(queryNode) + elementType = unwrapResolvedType(elementType) + elementType = refineSummaryProjectionType(elementType, item, result.TypeContext, source) + if elementType != nil { + registerOpts := []x.Option{x.WithName(typeName), x.WithForceFlag()} + if ctx := result.TypeContext; ctx != nil { + if pkgPath := strings.TrimSpace(ctx.PackagePath); pkgPath != "" { + registerOpts = append(registerOpts, x.WithPkgPath(pkgPath)) + } + } + registry.Register(x.NewType(elementType, registerOpts...)) + appendResolvedType(result, elementType, typeName, existing, result.TypeContext) + continue + } + } + if key := resolveTypeKey(typeName, resolver, registry); key != "" { + if resolved := registry.Lookup(key); resolved != nil && resolved.Type != nil { + appendResolvedType(result, resolved.Type, typeName, existing, result.TypeContext) + continue + } + } + } +} + +func refineSummaryProjectionType(summaryType reflect.Type, item *plan.View, ctx *typectx.Context, source *shape.Source) reflect.Type { + summaryType = unwrapResolvedType(summaryType) + if summaryType == nil || summaryType.Kind() != reflect.Struct || item == nil { + return summaryType + } + ownerType := unwrapResolvedType(item.ElementType) + if ownerType == nil { + ownerType = unwrapResolvedType(item.FieldType) + } + if ownerType == nil || ownerType.Kind() != reflect.Struct { + ownerType = resolveSummaryOwnerType(item, ctx, source) + } + if ownerType == nil || ownerType.Kind() != reflect.Struct { + return summaryType + } + ownerFields := map[string]reflect.StructField{} + for i := 0; i < ownerType.NumField(); i++ { + field := ownerType.Field(i) + ownerFields[strings.ToUpper(strings.TrimSpace(field.Name))] = field + if sqlxName := summaryTagName(field.Tag.Get("sqlx")); sqlxName != "" { + ownerFields[strings.ToUpper(sqlxName)] = field + } + } + fields := make([]reflect.StructField, 0, summaryType.NumField()) + changed := false + for i := 0; i < summaryType.NumField(); i++ { + field := summaryType.Field(i) + if ownerField, ok := ownerFields[strings.ToUpper(summaryLookupName(field))]; ok && ownerField.Type != nil && ownerField.Type != field.Type { + field.Type = ownerField.Type + changed = true + } + fields = append(fields, field) + } + if !changed { + return summaryType + } + return reflect.StructOf(fields) +} + +func resolveSummaryOwnerType(item *plan.View, ctx *typectx.Context, source *shape.Source) reflect.Type { + if item == nil { + return nil + } + for _, candidate := range summaryOwnerTypeCandidates(item) { + if linked := lookupLinkedType(candidate, ctx, source); linked != nil { + linked = unwrapResolvedType(linked) + if linked != nil && linked.Kind() == reflect.Struct { + return linked + } + } + } + return nil +} + +func summaryOwnerTypeCandidates(item *plan.View) []string { + if item == nil { + return nil + } + result := make([]string, 0, 6) + seen := map[string]bool{} + appendCandidate := func(value string) { + value = strings.TrimSpace(value) + if value == "" || seen[value] { + return + } + seen[value] = true + result = append(result, value) + } + if item.Declaration != nil { + appendCandidate(item.Declaration.DataType) + appendCandidate(item.Declaration.Of) + } + appendCandidate(item.SchemaType) + name := toExportedTypeName(item.Name) + if name != "" { + appendCandidate(name + "View") + appendCandidate(name) + } + return result +} + +func summaryLookupName(field reflect.StructField) string { + if sqlxName := summaryTagName(field.Tag.Get("sqlx")); sqlxName != "" { + return sqlxName + } + return strings.TrimSpace(field.Name) +} + +func summaryTagName(tag string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return "" + } + if strings.HasPrefix(tag, "name=") { + tag = strings.TrimPrefix(tag, "name=") + } + if idx := strings.Index(tag, ","); idx != -1 { + tag = tag[:idx] + } + return strings.TrimSpace(tag) +} + +func appendResolvedType(result *plan.Result, rType reflect.Type, typeName string, existing map[string]bool, ctx *typectx.Context) { + rType = unwrapResolvedType(rType) + typeName = strings.TrimSpace(typeName) + if result == nil || rType == nil || typeName == "" { + return + } + key := strings.ToLower(typeName) + if existing[key] { + return + } + typeExpr, typePkg := summarySchemaTypeExpression(typeName, ctx) + result.Types = append(result.Types, &plan.Type{ + Name: typeName, + DataType: typeExpr, + Cardinality: string(planStateOne()), + Package: typePkg, + ModulePath: strings.TrimSpace(rType.PkgPath()), + }) + existing[key] = true +} + +func summaryTypeName(summaryName string) string { + summaryName = strings.TrimSpace(summaryName) + if summaryName == "" { + return "" + } + if strings.HasSuffix(summaryName, "View") { + return summaryName + } + return toExportedTypeName(summaryName) + "View" +} + +func summarySchemaTypeExpression(typeName string, ctx *typectx.Context) (string, string) { + typeName = strings.TrimSpace(typeName) + if typeName == "" { + return "", "" + } + if ctx != nil { + if pkgAlias := strings.TrimSpace(ctx.PackageName); pkgAlias != "" { + return "*" + pkgAlias + "." + typeName, pkgAlias + } + if pkgPath := strings.TrimSpace(ctx.PackagePath); pkgPath != "" { + return "*" + packageAlias(pkgPath, ctx) + "." + typeName, packageAlias(pkgPath, ctx) + } + } + return "*" + typeName, "" +} + +func planStateOne() string { + return "one" +} + +func resolveViewType(item *plan.View, root bool, rootTypeKey string, resolver *typectx.Resolver, registry *x.Registry, ctx *typectx.Context, source *shape.Source) *x.Type { + for _, candidate := range viewTypeCandidates(item, root, rootTypeKey) { + if key := resolveTypeKey(candidate, resolver, registry); key != "" { + if registry != nil { + if resolved := registry.Lookup(key); resolved != nil && resolved.Type != nil { + return resolved + } + } + } + if linked := lookupLinkedType(candidate, ctx, source); linked != nil { + return x.NewType(linked) + } + } + return nil +} + func resolveRootTypeKey(source *shape.Source, resolver *typectx.Resolver, registry *x.Registry) string { if source == nil || registry == nil { return "" @@ -75,9 +353,9 @@ func resolveRootTypeKey(source *shape.Source, resolver *typectx.Resolver, regist return resolveTypeKey(x.NewType(rType).Key(), resolver, registry) } -func resolveViewTypeKey(item *plan.View, root bool, rootTypeKey string, resolver *typectx.Resolver, registry *x.Registry) string { - if item == nil || registry == nil { - return "" +func viewTypeCandidates(item *plan.View, root bool, rootTypeKey string) []string { + if item == nil { + return nil } candidates := make([]string, 0, 8) seen := map[string]bool{} @@ -106,12 +384,7 @@ func resolveViewTypeKey(item *plan.View, root bool, rootTypeKey string, resolver appendCandidate(name + "View") appendCandidate(name) } - for _, candidate := range candidates { - if key := resolveTypeKey(candidate, resolver, registry); key != "" { - return key - } - } - return "" + return candidates } func resolveTypeKey(typeExpr string, resolver *typectx.Resolver, registry *x.Registry) string { @@ -236,3 +509,297 @@ func unwrapResolvedType(rType reflect.Type) reflect.Type { } return nil } + +func lookupLinkedType(typeExpr string, ctx *typectx.Context, source *shape.Source) reflect.Type { + base := normalizeTypeLookupKey(typeExpr) + if base == "" { + return nil + } + if pkg, name, ok := splitQualifiedType(base); ok { + if fullPkg := packagePathForAlias(pkg, ctx); fullPkg != "" { + if linked := xunsafe.LookupType(fullPkg + "/" + name); linked != nil { + return linked + } + if linked := lookupASTType(fullPkg, name, ctx, source); linked != nil { + return linked + } + } + return nil + } + if ctx != nil && strings.TrimSpace(ctx.PackagePath) != "" { + if linked := xunsafe.LookupType(strings.TrimSpace(ctx.PackagePath) + "/" + base); linked != nil { + return linked + } + if linked := lookupASTType(strings.TrimSpace(ctx.PackagePath), base, ctx, source); linked != nil { + return linked + } + } + return nil +} + +func splitQualifiedType(value string) (string, string, bool) { + index := strings.Index(value, ".") + if index <= 0 || index+1 >= len(value) { + return "", "", false + } + return strings.TrimSpace(value[:index]), strings.TrimSpace(value[index+1:]), true +} + +func packagePathForAlias(alias string, ctx *typectx.Context) string { + alias = strings.TrimSpace(alias) + if alias == "" || ctx == nil { + return "" + } + for _, item := range ctx.Imports { + if strings.TrimSpace(item.Alias) == alias { + return strings.TrimSpace(item.Package) + } + } + if strings.TrimSpace(ctx.PackageName) == alias { + return strings.TrimSpace(ctx.PackagePath) + } + return "" +} + +func lookupASTType(pkgPath, typeName string, ctx *typectx.Context, source *shape.Source) reflect.Type { + pkgDir := resolveTypePackageDir(pkgPath, ctx, source) + if pkgDir == "" { + return nil + } + return parseNamedStructType(pkgDir, typeName) +} + +func resolveTypePackageDir(pkgPath string, ctx *typectx.Context, source *shape.Source) string { + if ctx == nil { + return "" + } + moduleRoot := nearestModuleRoot(source) + if moduleRoot == "" { + if strings.TrimSpace(ctx.PackagePath) == strings.TrimSpace(pkgPath) { + if dir := strings.TrimSpace(ctx.PackageDir); dir != "" { + if filepath.IsAbs(dir) { + return dir + } + } + } + return "" + } + modulePath := detectModulePath(moduleRoot) + if modulePath != "" { + if rel, ok := packagePathRelative(modulePath, pkgPath); ok { + if rel == "" { + return moduleRoot + } + return filepath.Join(moduleRoot, filepath.FromSlash(rel)) + } + } + if strings.TrimSpace(ctx.PackagePath) == strings.TrimSpace(pkgPath) { + if dir := strings.TrimSpace(ctx.PackageDir); dir != "" { + if filepath.IsAbs(dir) { + return dir + } + return filepath.Join(moduleRoot, filepath.FromSlash(dir)) + } + } + return "" +} + +func packageNameForPath(pkgPath string, ctx *typectx.Context) string { + if ctx != nil && strings.TrimSpace(ctx.PackagePath) == strings.TrimSpace(pkgPath) && strings.TrimSpace(ctx.PackageName) != "" { + return strings.TrimSpace(ctx.PackageName) + } + if index := strings.LastIndex(strings.TrimSpace(pkgPath), "/"); index != -1 { + return strings.TrimSpace(pkgPath[index+1:]) + } + return strings.TrimSpace(pkgPath) +} + +func nearestModuleRoot(source *shape.Source) string { + if source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + current := filepath.Dir(strings.TrimSpace(source.Path)) + for current != "" && current != string(filepath.Separator) && current != "." { + if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { + return current + } + parent := filepath.Dir(current) + if parent == current { + break + } + current = parent + } + return "" +} + +func parseNamedStructType(pkgDir, typeName string) reflect.Type { + fset := token.NewFileSet() + pkgs, err := parser.ParseDir(fset, pkgDir, nil, parser.ParseComments) + if err != nil || len(pkgs) == 0 { + return nil + } + specs := map[string]*ast.TypeSpec{} + for _, pkg := range pkgs { + for _, file := range pkg.Files { + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + for _, spec := range gen.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok || typeSpec.Name == nil { + continue + } + specs[typeSpec.Name.Name] = typeSpec + } + } + } + } + cache := map[string]reflect.Type{} + inProgress := map[string]bool{} + var buildNamed func(name string) reflect.Type + var buildExpr func(expr ast.Expr) reflect.Type + + buildNamed = func(name string) reflect.Type { + if cached, ok := cache[name]; ok { + return cached + } + if inProgress[name] { + return reflect.TypeOf(new(interface{})).Elem() + } + spec := specs[name] + if spec == nil { + return nil + } + inProgress[name] = true + rType := buildExpr(spec.Type) + delete(inProgress, name) + if rType != nil { + cache[name] = rType + } + return rType + } + + buildExpr = func(expr ast.Expr) reflect.Type { + switch actual := expr.(type) { + case *ast.Ident: + switch actual.Name { + case "string": + return reflect.TypeOf("") + case "bool": + return reflect.TypeOf(true) + case "int": + return reflect.TypeOf(int(0)) + case "int8": + return reflect.TypeOf(int8(0)) + case "int16": + return reflect.TypeOf(int16(0)) + case "int32": + return reflect.TypeOf(int32(0)) + case "int64": + return reflect.TypeOf(int64(0)) + case "uint": + return reflect.TypeOf(uint(0)) + case "uint8": + return reflect.TypeOf(uint8(0)) + case "uint16": + return reflect.TypeOf(uint16(0)) + case "uint32": + return reflect.TypeOf(uint32(0)) + case "uint64": + return reflect.TypeOf(uint64(0)) + case "float32": + return reflect.TypeOf(float32(0)) + case "float64": + return reflect.TypeOf(float64(0)) + case "interface{}", "any": + return reflect.TypeOf(new(interface{})).Elem() + default: + return buildNamed(actual.Name) + } + case *ast.StarExpr: + if inner := buildExpr(actual.X); inner != nil { + return reflect.PtrTo(inner) + } + case *ast.ArrayType: + if actual.Len == nil { + if inner := buildExpr(actual.Elt); inner != nil { + return reflect.SliceOf(inner) + } + } + case *ast.MapType: + key := buildExpr(actual.Key) + value := buildExpr(actual.Value) + if key != nil && value != nil { + return reflect.MapOf(key, value) + } + case *ast.InterfaceType: + return reflect.TypeOf(new(interface{})).Elem() + case *ast.SelectorExpr: + if ident, ok := actual.X.(*ast.Ident); ok { + if ident.Name == "time" && actual.Sel != nil && actual.Sel.Name == "Time" { + return reflect.TypeOf(time.Time{}) + } + } + case *ast.StructType: + fields := make([]reflect.StructField, 0, len(actual.Fields.List)) + seen := map[string]bool{} + for _, field := range actual.Fields.List { + if field == nil { + continue + } + fieldType := buildExpr(field.Type) + if fieldType == nil { + continue + } + tag := reflect.StructTag("") + if field.Tag != nil { + tag = reflect.StructTag(strings.Trim(field.Tag.Value, "`")) + } + if len(field.Names) == 0 { + if name := exportedEmbeddedFieldName(field.Type); name != "" { + if seen[name] { + continue + } + seen[name] = true + fields = append(fields, reflect.StructField{Name: name, Type: fieldType, Tag: tag, Anonymous: true}) + } + continue + } + for _, name := range field.Names { + if name == nil || !name.IsExported() { + continue + } + if seen[name.Name] { + continue + } + seen[name.Name] = true + fields = append(fields, reflect.StructField{Name: name.Name, Type: fieldType, Tag: tag}) + } + } + if len(fields) > 0 { + return reflect.StructOf(fields) + } + } + return nil + } + + return buildNamed(typeName) +} + +func exportedEmbeddedFieldName(expr ast.Expr) string { + switch actual := expr.(type) { + case *ast.Ident: + if actual.IsExported() { + return actual.Name + } + case *ast.SelectorExpr: + if actual.Sel != nil && actual.Sel.IsExported() { + return actual.Sel.Name + } + case *ast.StarExpr: + return exportedEmbeddedFieldName(actual.X) + } + return "" +} diff --git a/repository/shape/compile/type_support_summary_test.go b/repository/shape/compile/type_support_summary_test.go new file mode 100644 index 000000000..158056e48 --- /dev/null +++ b/repository/shape/compile/type_support_summary_test.go @@ -0,0 +1,241 @@ +package compile + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/x" +) + +func TestApplyLinkedTypeSupport_RegistersSummaryTypes(t *testing.T) { + source := &shape.Source{ + TypeRegistry: x.NewRegistry(), + } + result := &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "meta_nested", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/meta_nested", + }, + Views: []*plan.View{ + { + Name: "vendor", + SummaryName: "Meta", + Summary: "SELECT COUNT(*) AS CNT, 1 AS PAGE_CNT FROM ($View.vendor.SQL) t", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Cardinality: "many", + }, + { + Name: "products", + SummaryName: "ProductsMeta", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) t GROUP BY VENDOR_ID", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Cardinality: "many", + }, + }, + } + + applySummaryTypeSupport(result, source) + applyLinkedTypeSupport(result, source) + + registry := source.EnsureTypeRegistry() + require.NotNil(t, registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/meta_nested.MetaView")) + require.NotNil(t, registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/meta_nested.ProductsMetaView")) + + var names []string + for _, item := range result.Types { + if item != nil { + names = append(names, item.Name) + } + } + assert.Contains(t, names, "MetaView") + assert.Contains(t, names, "ProductsMetaView") +} + +func TestApplyLinkedTypeSupport_SkipsPlaceholderLinkedViewTypes(t *testing.T) { + type placeholderVendorView struct { + Col1 string `sqlx:"name=col_1"` + Col2 string `sqlx:"name=col_2"` + } + + registry := x.NewRegistry() + registry.Register(x.NewType( + reflect.TypeOf(placeholderVendorView{}), + x.WithName("VendorView"), + x.WithPkgPath("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary"), + )) + + source := &shape.Source{TypeRegistry: registry} + result := &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "multi_summary", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary", + }, + Views: []*plan.View{ + { + Name: "vendor", + SchemaType: "*VendorView", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Cardinality: "many", + }, + }, + } + + applyLinkedTypeSupport(result, source) + + assert.Equal(t, reflect.TypeOf(map[string]interface{}{}), result.Views[0].ElementType) + assert.Equal(t, reflect.TypeOf([]map[string]interface{}{}), result.Views[0].FieldType) +} + +func TestApplySummaryTypeSupport_RegistersSummaryTypesWithoutLinkedViews(t *testing.T) { + source := &shape.Source{ + TypeRegistry: x.NewRegistry(), + } + result := &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "meta_nested", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/meta_nested", + }, + Views: []*plan.View{ + { + Name: "vendor", + SummaryName: "Meta", + Summary: "SELECT COUNT(*) AS CNT, 1 AS PAGE_CNT FROM ($View.vendor.SQL) t", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Cardinality: "many", + }, + { + Name: "products", + SummaryName: "ProductsMeta", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) t GROUP BY VENDOR_ID", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Cardinality: "many", + }, + }, + } + + applySummaryTypeSupport(result, source) + + registry := source.EnsureTypeRegistry() + require.NotNil(t, registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/meta_nested.MetaView")) + require.NotNil(t, registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/meta_nested.ProductsMetaView")) +} + +func TestApplySummaryTypeSupport_PreservesOwnerColumnTypes(t *testing.T) { + type productView struct { + VendorId *int `sqlx:"VENDOR_ID"` + } + + source := &shape.Source{TypeRegistry: x.NewRegistry()} + result := &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "multi_summary", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary", + }, + Views: []*plan.View{ + { + Name: "products", + SummaryName: "ProductsMeta", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) t GROUP BY VENDOR_ID", + FieldType: reflect.TypeOf([]productView{}), + ElementType: reflect.TypeOf(productView{}), + Cardinality: "many", + }, + }, + } + + applySummaryTypeSupport(result, source) + + registry := source.EnsureTypeRegistry() + registered := registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary.ProductsMetaView") + require.NotNil(t, registered) + require.NotNil(t, registered.Type) + field, ok := registered.Type.FieldByName("VendorId") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), field.Type) +} + +func TestApplySummaryTypeSupport_InfersNullableComputedSummaryColumns(t *testing.T) { + source := &shape.Source{TypeRegistry: x.NewRegistry()} + result := &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "multi_summary", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary", + }, + Views: []*plan.View{ + { + Name: "vendor", + SummaryName: "Meta", + Summary: "SELECT CAST(1 + (COUNT(1) / 25) AS SIGNED) AS PAGE_CNT, COUNT(1) AS CNT FROM ($View.vendor.SQL) t", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Cardinality: "many", + }, + }, + } + + applySummaryTypeSupport(result, source) + + registry := source.EnsureTypeRegistry() + registered := registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary.MetaView") + require.NotNil(t, registered) + require.NotNil(t, registered.Type) + pageCnt, ok := registered.Type.FieldByName("PageCnt") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), pageCnt.Type) + cnt, ok := registered.Type.FieldByName("Cnt") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf(int(0)), cnt.Type) +} + +func TestApplySummaryTypeSupport_OverridesStaleRegisteredSummaryType(t *testing.T) { + type staleProductsMetaView struct { + VendorId string `sqlx:"VENDOR_ID"` + } + type productView struct { + VendorId *int `sqlx:"VENDOR_ID"` + } + + registry := x.NewRegistry() + registry.Register(x.NewType( + reflect.TypeOf(staleProductsMetaView{}), + x.WithName("ProductsMetaView"), + x.WithPkgPath("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary"), + )) + + source := &shape.Source{TypeRegistry: registry} + result := &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "multi_summary", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary", + }, + Views: []*plan.View{ + { + Name: "products", + SummaryName: "ProductsMeta", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) t GROUP BY VENDOR_ID", + FieldType: reflect.TypeOf([]productView{}), + ElementType: reflect.TypeOf(productView{}), + Cardinality: "many", + }, + }, + } + + applySummaryTypeSupport(result, source) + + registered := registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary.ProductsMetaView") + require.NotNil(t, registered) + require.NotNil(t, registered.Type) + field, ok := registered.Type.FieldByName("VendorId") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), field.Type) +} diff --git a/repository/shape/compile/typectx_defaults.go b/repository/shape/compile/typectx_defaults.go index 5561d36a3..89d4492ec 100644 --- a/repository/shape/compile/typectx_defaults.go +++ b/repository/shape/compile/typectx_defaults.go @@ -5,6 +5,7 @@ import ( "path" "path/filepath" "strings" + "unicode" "github.com/viant/datly/repository/shape" "github.com/viant/datly/repository/shape/typectx" @@ -13,6 +14,7 @@ import ( func applyTypeContextDefaults(ctx *typectx.Context, source *shape.Source, opts *shape.CompileOptions, layout compilePathLayout) *typectx.Context { ret := cloneTypeContext(ctx) + ret = hydrateExplicitTypeContext(ret, source, layout) if shouldInferTypeContext(opts) { ret = mergeTypeContext(ret, inferDatlyGenTypeContext(source, layout)) } @@ -34,6 +36,75 @@ func applyTypeContextDefaults(ctx *typectx.Context, source *shape.Source, opts * return normalizeTypeContext(ret) } +func hydrateExplicitTypeContext(ctx *typectx.Context, source *shape.Source, layout compilePathLayout) *typectx.Context { + if ctx == nil { + return nil + } + if strings.TrimSpace(ctx.PackagePath) == "" && strings.TrimSpace(ctx.DefaultPackage) != "" { + ctx.PackagePath = strings.TrimSpace(ctx.DefaultPackage) + } + if strings.TrimSpace(ctx.PackageName) == "" { + base := path.Base(strings.TrimSpace(ctx.PackagePath)) + if base == "." || base == "/" || base == "" { + base = path.Base(strings.TrimSpace(ctx.DefaultPackage)) + } + ctx.PackageName = sanitizePackageName(base) + } + if strings.TrimSpace(ctx.PackageDir) == "" { + parsed, ok := parseSourceLayout(source, layout) + if ok { + modulePath := detectModulePath(parsed.projectRoot) + pkgPath := strings.TrimSpace(ctx.PackagePath) + if modulePath != "" && pkgPath != "" { + if rel, ok := packagePathRelative(modulePath, pkgPath); ok { + ctx.PackageDir = rel + } + } + } + } + return ctx +} + +func packagePathRelative(modulePath, packagePath string) (string, bool) { + modulePath = strings.Trim(strings.TrimSpace(modulePath), "/") + packagePath = strings.Trim(strings.TrimSpace(packagePath), "/") + if modulePath == "" || packagePath == "" { + return "", false + } + if packagePath == modulePath { + return "", true + } + prefix := modulePath + "/" + if !strings.HasPrefix(packagePath, prefix) { + return "", false + } + return strings.TrimPrefix(packagePath, prefix), true +} + +func sanitizePackageName(name string) string { + name = strings.TrimSpace(strings.ToLower(name)) + if name == "" { + return "" + } + var out strings.Builder + for _, r := range name { + switch { + case r == '_' || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9'): + out.WriteRune(r) + case r == '-' || unicode.IsSpace(r): + out.WriteRune('_') + } + } + result := strings.Trim(out.String(), "_") + if result == "" { + return "" + } + if result[0] >= '0' && result[0] <= '9' { + return "p" + result + } + return result +} + func shouldInferTypeContext(opts *shape.CompileOptions) bool { if opts == nil || opts.InferTypeContext == nil { return true diff --git a/repository/shape/compile/typectx_defaults_test.go b/repository/shape/compile/typectx_defaults_test.go index aa3d01d8c..d396fbfc3 100644 --- a/repository/shape/compile/typectx_defaults_test.go +++ b/repository/shape/compile/typectx_defaults_test.go @@ -43,6 +43,17 @@ func TestApplyTypeContextDefaults_Matrix(t *testing.T) { require.Equal(t, "github.com/acme/manual", got.DefaultPackage) }) + t.Run("default package hydrates package path dir and sanitized name", func(t *testing.T) { + input := &typectx.Context{ + DefaultPackage: "github.vianttech.com/viant/platform/pkg/dev/events-one-one", + } + got := applyTypeContextDefaults(input, source, nil, layout) + require.NotNil(t, got) + require.Equal(t, "pkg/dev/events-one-one", got.PackageDir) + require.Equal(t, "events_one_one", got.PackageName) + require.Equal(t, "github.vianttech.com/viant/platform/pkg/dev/events-one-one", got.PackagePath) + }) + t.Run("compile override wins over both", func(t *testing.T) { input := &typectx.Context{ PackageDir: "pkg/manual", diff --git a/repository/shape/compile/viewdecl.go b/repository/shape/compile/viewdecl.go index 9e6c14c88..e60fa78c6 100644 --- a/repository/shape/compile/viewdecl.go +++ b/repository/shape/compile/viewdecl.go @@ -12,30 +12,36 @@ import ( ) type declaredView struct { - Name string - SQL string - URI string - Connector string - Cardinality string - Tag string - Codec string - CodecArgs []string - HandlerName string - HandlerArgs []string - StatusCode *int - ErrorMessage string - QuerySelector string - CacheRef string - Limit *int - Cacheable *bool - When string - Scope string - DataType string - Of string - Value string - Async bool - Output bool - Predicates []declaredPredicate + Name string + VirtualSummary bool + SQL string + URI string + Connector string + Cardinality string + Required bool + CardinalitySet bool + Tag string + TypeName string + Dest string + Codec string + CodecArgs []string + HandlerName string + HandlerArgs []string + StatusCode *int + ErrorMessage string + QuerySelector string + CacheRef string + Limit *int + Cacheable *bool + When string + Scope string + DataType string + Of string + Value string + Async bool + Output bool + Predicates []declaredPredicate + ColumnsConfig map[string]*declaredColumnConfig } type declaredPredicate struct { @@ -45,6 +51,12 @@ type declaredPredicate struct { Arguments []string } +type declaredColumnConfig struct { + DataType string + Tag string + Groupable *bool +} + const ( vdWhitespaceToken = iota vdSetToken @@ -74,14 +86,14 @@ func extractDeclaredViews(dql string) ([]*declaredView, []*dqlshape.Diagnostic) var views []*declaredView var diags []*dqlshape.Diagnostic for _, block := range extractSetBlocks(dql) { - holder, kind, location, tail, ok := parseSetDeclarationBody(block.Body) + holder, kind, location, tail, tailOffset, ok := parseSetDeclarationBody(block.Body) if !ok { continue } - if kind != "view" && kind != "data_view" { + if kind != "view" && kind != "data_view" && !isOutputSummaryDeclaration(kind, location) { continue } - sqlText := extractDeclarationSQL(tail) + sqlText, errorStatusCode := extractDeclarationSQLWithStatus(tail) if sqlText == "" { diags = append(diags, &dqlshape.Diagnostic{ Code: dqldiag.CodeViewMissingSQL, @@ -92,16 +104,28 @@ func extractDeclaredViews(dql string) ([]*declaredView, []*dqlshape.Diagnostic) }) continue } - name := pipeline.SanitizeName(location) + name := pipeline.SanitizeName(holder) if name == "" { - name = pipeline.SanitizeName(holder) + name = pipeline.SanitizeName(location) } if name == "" { continue } - view := &declaredView{Name: name, SQL: strings.TrimSpace(sqlText)} - applyDeclaredViewOptions(view, tail, dql, block.Offset, &diags) + view := &declaredView{ + Name: name, + SQL: strings.TrimSpace(sqlText), + VirtualSummary: isOutputSummaryDeclaration(kind, location), + } + if errorStatusCode != nil { + view.StatusCode = errorStatusCode + } + applyDeclaredViewOptions(view, tail, dql, block.BodyOffset+tailOffset, &diags) views = append(views, view) } return views, diags } + +func isOutputSummaryDeclaration(kind, location string) bool { + return strings.EqualFold(strings.TrimSpace(kind), "output") && + strings.EqualFold(strings.TrimSpace(location), "summary") +} diff --git a/repository/shape/compile/viewdecl_append.go b/repository/shape/compile/viewdecl_append.go index 6ae2fc663..0f91303fc 100644 --- a/repository/shape/compile/viewdecl_append.go +++ b/repository/shape/compile/viewdecl_append.go @@ -13,6 +13,7 @@ func appendDeclaredViews(rawDQL string, result *plan.Result) { if result == nil { return } + appendRootOutputViewDeclaration(rawDQL, result) declared, diags := extractDeclaredViews(rawDQL) if len(diags) > 0 { result.Diagnostics = append(result.Diagnostics, diags...) @@ -21,9 +22,17 @@ func appendDeclaredViews(rawDQL string, result *plan.Result) { if item == nil || strings.TrimSpace(item.Name) == "" || strings.TrimSpace(item.SQL) == "" { continue } + if item.VirtualSummary { + if root := lookupRootView(result); root != nil && strings.TrimSpace(root.Summary) == "" { + root.Summary = strings.TrimSpace(item.SQL) + root.SummaryName = strings.TrimSpace(item.Name) + } + continue + } if parent := lookupSummaryParentView(result, item.SQL); parent != nil { if strings.TrimSpace(parent.Summary) == "" { - parent.Summary = strings.TrimSpace(item.SQL) + parent.Summary = normalizeSummarySQLForParent(parent, item.SQL) + parent.SummaryName = strings.TrimSpace(item.Name) } continue } @@ -43,20 +52,23 @@ func appendDeclaredViews(rawDQL string, result *plan.Result) { ElementType: reflect.TypeOf(map[string]interface{}{}), Declaration: buildViewDeclaration(item), } + if item.Required && !item.CardinalitySet { + view.Cardinality = "one" + } if item.Cardinality != "" { view.Cardinality = item.Cardinality } if queryNode, err := sqlparser.ParseQuery(item.SQL); err == nil && queryNode != nil { if inferredName, inferredTable, err := pipeline.InferRoot(queryNode, item.Name); err == nil { - view.Name = inferredName - view.Holder = inferredName - view.Path = inferredName - view.Table = inferredTable + _ = inferredName + if strings.TrimSpace(inferredTable) != "" { + view.Table = inferredTable + } } if fType, eType, card := pipeline.InferProjectionType(queryNode); fType != nil && eType != nil { view.FieldType = fType view.ElementType = eType - if item.Cardinality == "" { + if item.Cardinality == "" && !(item.Required && !item.CardinalitySet) { view.Cardinality = card } } @@ -66,10 +78,158 @@ func appendDeclaredViews(rawDQL string, result *plan.Result) { } } +func appendRootOutputViewDeclaration(rawDQL string, result *plan.Result) { + if result == nil { + return + } + root := lookupRootView(result) + if root == nil { + return + } + for _, block := range extractSetBlocks(rawDQL) { + _, kind, location, tail, tailOffset, ok := parseSetDeclarationBody(block.Body) + if !ok || !strings.EqualFold(strings.TrimSpace(kind), "output") || !strings.EqualFold(strings.TrimSpace(location), "view") { + continue + } + view := &declaredView{} + applyDeclaredViewOptions(view, tail, rawDQL, block.BodyOffset+tailOffset, &result.Diagnostics) + mergeViewDeclaration(root, buildViewDeclaration(view)) + if view.Required && !view.CardinalitySet { + root.Cardinality = "one" + } + if view.Cardinality != "" { + root.Cardinality = view.Cardinality + } + } +} + +func mergeViewDeclaration(target *plan.View, declared *plan.ViewDeclaration) { + if target == nil || declared == nil { + return + } + if target.Declaration == nil { + target.Declaration = declared + return + } + dst := target.Declaration + if strings.TrimSpace(dst.Tag) == "" { + dst.Tag = declared.Tag + } + if strings.TrimSpace(dst.TypeName) == "" { + dst.TypeName = declared.TypeName + } + if strings.TrimSpace(dst.Dest) == "" { + dst.Dest = declared.Dest + } + if strings.TrimSpace(dst.Codec) == "" { + dst.Codec = declared.Codec + if len(dst.CodecArgs) == 0 { + dst.CodecArgs = append([]string{}, declared.CodecArgs...) + } + } + if strings.TrimSpace(dst.HandlerName) == "" { + dst.HandlerName = declared.HandlerName + if len(dst.HandlerArgs) == 0 { + dst.HandlerArgs = append([]string{}, declared.HandlerArgs...) + } + } + if dst.StatusCode == nil { + dst.StatusCode = declared.StatusCode + } + if strings.TrimSpace(dst.ErrorMessage) == "" { + dst.ErrorMessage = declared.ErrorMessage + } + if strings.TrimSpace(dst.QuerySelector) == "" { + dst.QuerySelector = declared.QuerySelector + } + if strings.TrimSpace(dst.CacheRef) == "" { + dst.CacheRef = declared.CacheRef + } + if dst.Limit == nil { + dst.Limit = declared.Limit + } + if dst.Cacheable == nil { + dst.Cacheable = declared.Cacheable + } + if strings.TrimSpace(dst.When) == "" { + dst.When = declared.When + } + if strings.TrimSpace(dst.Scope) == "" { + dst.Scope = declared.Scope + } + if strings.TrimSpace(dst.DataType) == "" { + dst.DataType = declared.DataType + } + if strings.TrimSpace(dst.Of) == "" { + dst.Of = declared.Of + } + if strings.TrimSpace(dst.Value) == "" { + dst.Value = declared.Value + } + dst.Async = dst.Async || declared.Async + dst.Output = dst.Output || declared.Output + if len(dst.Predicates) == 0 && len(declared.Predicates) > 0 { + dst.Predicates = append([]*plan.ViewPredicate{}, declared.Predicates...) + } + if len(declared.ColumnsConfig) > 0 { + if dst.ColumnsConfig == nil { + dst.ColumnsConfig = map[string]*plan.ViewColumnConfig{} + } + for name, cfg := range declared.ColumnsConfig { + if strings.TrimSpace(name) == "" || cfg == nil { + continue + } + dst.ColumnsConfig[name] = &plan.ViewColumnConfig{ + DataType: strings.TrimSpace(cfg.DataType), + Tag: strings.TrimSpace(cfg.Tag), + Groupable: cloneBoolPtr(cfg.Groupable), + } + } + } +} + +func normalizeSummarySQLForParent(parent *plan.View, sqlText string) string { + normalized := strings.TrimSpace(sqlText) + if parent == nil || normalized == "" { + return normalized + } + parentName := strings.TrimSpace(parent.Name) + if parentName == "" { + return normalized + } + for _, candidate := range []string{ + "$View." + parentName + ".SQL", + "$view." + strings.ToLower(parentName) + ".sql", + } { + normalized = strings.ReplaceAll(normalized, candidate, "$View.NonWindowSQL") + } + return normalized +} + +func lookupRootView(result *plan.Result) *plan.View { + if result == nil { + return nil + } + if len(result.Views) > 0 && result.Views[0] != nil { + return result.Views[0] + } + for _, item := range result.ViewsByName { + if item != nil { + return item + } + } + return nil +} + func lookupSummaryParentView(result *plan.Result, sqlText string) *plan.View { if result == nil || strings.TrimSpace(sqlText) == "" { return nil } + if hasRootSummaryReference(sqlText) { + if len(result.Views) > 0 && result.Views[0] != nil { + return result.Views[0] + } + } parent, ok := findSummaryParentReference(sqlText) if !ok { return nil @@ -134,6 +294,13 @@ func findSummaryParentReference(input string) (string, bool) { return "", false } +func hasRootSummaryReference(input string) bool { + if strings.TrimSpace(input) == "" { + return false + } + return strings.Contains(strings.ToLower(input), "$view.nonwindowsql") +} + func isCompileIdentifierStart(ch byte) bool { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_' } @@ -148,6 +315,8 @@ func buildViewDeclaration(item *declaredView) *plan.ViewDeclaration { } ret := &plan.ViewDeclaration{ Tag: item.Tag, + TypeName: item.TypeName, + Dest: item.Dest, Codec: item.Codec, CodecArgs: append([]string{}, item.CodecArgs...), HandlerName: item.HandlerName, @@ -177,12 +346,37 @@ func buildViewDeclaration(item *declaredView) *plan.ViewDeclaration { }) } } - if ret.Tag == "" && ret.Codec == "" && len(ret.CodecArgs) == 0 && ret.HandlerName == "" && + if len(item.ColumnsConfig) > 0 { + ret.ColumnsConfig = map[string]*plan.ViewColumnConfig{} + for name, cfg := range item.ColumnsConfig { + if strings.TrimSpace(name) == "" || cfg == nil { + continue + } + ret.ColumnsConfig[name] = &plan.ViewColumnConfig{ + DataType: strings.TrimSpace(cfg.DataType), + Tag: strings.TrimSpace(cfg.Tag), + Groupable: cloneBoolPtr(cfg.Groupable), + } + } + if len(ret.ColumnsConfig) == 0 { + ret.ColumnsConfig = nil + } + } + if ret.Tag == "" && ret.TypeName == "" && ret.Dest == "" && + ret.Codec == "" && len(ret.CodecArgs) == 0 && ret.HandlerName == "" && len(ret.HandlerArgs) == 0 && ret.StatusCode == nil && ret.ErrorMessage == "" && ret.QuerySelector == "" && ret.CacheRef == "" && ret.Limit == nil && ret.Cacheable == nil && ret.When == "" && ret.Scope == "" && ret.DataType == "" && ret.Of == "" && ret.Value == "" && - !ret.Async && !ret.Output && len(ret.Predicates) == 0 { + !ret.Async && !ret.Output && len(ret.Predicates) == 0 && len(ret.ColumnsConfig) == 0 { return nil } return ret } + +func cloneBoolPtr(value *bool) *bool { + if value == nil { + return nil + } + ret := *value + return &ret +} diff --git a/repository/shape/compile/viewdecl_options.go b/repository/shape/compile/viewdecl_options.go index dd8ea2fba..19f74b471 100644 --- a/repository/shape/compile/viewdecl_options.go +++ b/repository/shape/compile/viewdecl_options.go @@ -11,25 +11,35 @@ import ( ) func extractDeclarationSQL(fragment string) string { + sql, _ := extractDeclarationSQLWithStatus(fragment) + return sql +} + +func extractDeclarationSQLWithStatus(fragment string) (string, *int) { cursor := parsly.NewCursor("", []byte(fragment), 0) for cursor.Pos < cursor.InputSize { match := cursor.MatchAfterOptional(vdWhitespaceMatcher, vdCommentMatcher) if match.Code == vdCommentToken { text := match.Text(cursor) if len(text) < 4 { - return "" + return "", nil } - return normalizeHintSQL(text[2 : len(text)-2]) + return normalizeHintSQLWithStatus(text[2 : len(text)-2]) } cursor.Pos++ } - return "" + return "", nil } func normalizeHintSQL(body string) string { + sql, _ := normalizeHintSQLWithStatus(body) + return sql +} + +func normalizeHintSQLWithStatus(body string) (string, *int) { body = strings.TrimSpace(body) if body == "" { - return "" + return "", nil } if strings.HasPrefix(body, "{") { if closeIdx := strings.Index(body, "}"); closeIdx != -1 { @@ -37,87 +47,85 @@ func normalizeHintSQL(body string) string { } } if body == "" { - return "" + return "", nil } switch body[0] { case '?': body = strings.TrimSpace(body[1:]) case '!': + // Deprecated: legacy `!!NNN` prefix is still supported for backward compatibility. + // Prefer explicit declaration option: .WithStatusCode(NNN). + var statusCode *int body = strings.TrimSpace(body[1:]) if strings.HasPrefix(body, "!") { body = strings.TrimSpace(body[1:]) } if len(body) >= 3 { - var status int - if _, err := fmt.Sscanf(body[:3], "%d", &status); err == nil { + var legacyStatus int + if _, err := fmt.Sscanf(body[:3], "%d", &legacyStatus); err == nil { + statusCode = &legacyStatus body = strings.TrimSpace(body[3:]) } } + return strings.TrimSpace(body), statusCode } - return strings.TrimSpace(body) + return strings.TrimSpace(body), nil } func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, diags *[]*dqlshape.Diagnostic) { if view == nil || strings.TrimSpace(tail) == "" { return } - cursor := parsly.NewCursor("", []byte(tail), 0) - for cursor.Pos < cursor.InputSize { - _ = cursor.MatchOne(vdWhitespaceMatcher) - if cursor.MatchOne(vdDotMatcher).Code != vdDotToken { - cursor.Pos++ - continue - } - _ = cursor.MatchOne(vdWhitespaceMatcher) - name, ok := readIdentifier(cursor) - if !ok { - continue - } - _ = cursor.MatchOne(vdWhitespaceMatcher) - group := cursor.MatchOne(vdExprGroupMatcher) - if group.Code != vdExprGroupToken { - continue - } - content := group.Text(cursor) - if len(content) < 2 { - continue - } - args := splitArgs(content[1 : len(content)-1]) + cursor := newOptionCursor(tail) + for cursor.next() { + name, args := cursor.option() + optionOffset := offset + cursor.start switch { case strings.EqualFold(name, "WithURI"): - if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, -1, dql, optionOffset, diags) { continue } view.URI = trimQuote(args[0]) case strings.EqualFold(name, "WithConnector"), strings.EqualFold(name, "Connector"): - if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, -1, dql, optionOffset, diags) { continue } view.Connector = trimQuote(args[0]) case strings.EqualFold(name, "Cardinality"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } card := strings.ToLower(strings.TrimSpace(trimQuote(args[0]))) switch card { case "one", "many": view.Cardinality = card + view.CardinalitySet = true default: *diags = append(*diags, &dqlshape.Diagnostic{ Code: dqldiag.CodeViewCardinality, Severity: dqlshape.SeverityWarning, Message: fmt.Sprintf("unsupported cardinality %q for declared view %q", args[0], view.Name), Hint: "use Cardinality('one') or Cardinality('many')", - Span: relationSpan(dql, offset), + Span: relationSpan(dql, optionOffset), }) } case strings.EqualFold(name, "WithTag"), strings.EqualFold(name, "Tag"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.Tag = trimQuote(args[0]) + case strings.EqualFold(name, "WithTypeName"), strings.EqualFold(name, "TypeName"), strings.EqualFold(name, "Type"): + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + view.TypeName = trimQuote(args[0]) + case strings.EqualFold(name, "WithDest"), strings.EqualFold(name, "Dest"): + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { + continue + } + view.Dest = trimQuote(args[0]) case strings.EqualFold(name, "WithCodec"), strings.EqualFold(name, "Codec"): - if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, -1, dql, optionOffset, diags) { continue } view.Codec = trimQuote(args[0]) @@ -126,7 +134,7 @@ func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, view.CodecArgs = append(view.CodecArgs, strings.TrimSpace(arg)) } case strings.EqualFold(name, "WithHandler"), strings.EqualFold(name, "Handler"): - if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, -1, dql, optionOffset, diags) { continue } view.HandlerName = trimQuote(args[0]) @@ -135,7 +143,7 @@ func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, view.HandlerArgs = append(view.HandlerArgs, strings.TrimSpace(arg)) } case strings.EqualFold(name, "WithStatusCode"), strings.EqualFold(name, "StatusCode"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } statusCode, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))) @@ -145,18 +153,18 @@ func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, Severity: dqlshape.SeverityWarning, Message: fmt.Sprintf("invalid status code %q for declared view %q", args[0], view.Name), Hint: "use numeric status code, e.g. StatusCode(400)", - Span: relationSpan(dql, offset), + Span: relationSpan(dql, optionOffset), }) continue } view.StatusCode = &statusCode case strings.EqualFold(name, "WithErrorMessage"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.ErrorMessage = trimQuote(args[0]) case strings.EqualFold(name, "WithPredicate"), strings.EqualFold(name, "Predicate"): - if !expectArgs(view, name, args, 2, -1, dql, offset, diags) { + if !expectArgs(view, name, args, 2, -1, dql, optionOffset, diags) { continue } view.Predicates = append(view.Predicates, declaredPredicate{ @@ -165,7 +173,7 @@ func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, Arguments: append([]string{}, args[2:]...), }) case strings.EqualFold(name, "EnsurePredicate"): - if !expectArgs(view, name, args, 2, -1, dql, offset, diags) { + if !expectArgs(view, name, args, 2, -1, dql, optionOffset, diags) { continue } view.Predicates = append(view.Predicates, declaredPredicate{ @@ -175,7 +183,7 @@ func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, Arguments: append([]string{}, args[2:]...), }) case strings.EqualFold(name, "QuerySelector"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.QuerySelector = trimQuote(args[0]) @@ -185,69 +193,121 @@ func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, Severity: dqlshape.SeverityWarning, Message: fmt.Sprintf("query selector %q can only be used with limit, offset, page, fields, orderby", view.QuerySelector), Hint: "use QuerySelector on declarations named limit/offset/page/fields/orderby", - Span: relationSpan(dql, offset), + Span: relationSpan(dql, optionOffset), }) } case strings.EqualFold(name, "WithCache"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.CacheRef = trimQuote(args[0]) case strings.EqualFold(name, "WithLimit"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } limit, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))) if err != nil { - appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid integer limit %q", args[0]), dql, offset, diags) + appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid integer limit %q", args[0]), dql, optionOffset, diags) continue } view.Limit = &limit case strings.EqualFold(name, "Cacheable"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } value, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[0]))) if err != nil { - appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid bool cacheable %q", args[0]), dql, offset, diags) + appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid bool cacheable %q", args[0]), dql, optionOffset, diags) continue } view.Cacheable = &value case strings.EqualFold(name, "When"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.When = trimQuote(args[0]) case strings.EqualFold(name, "Scope"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.Scope = trimQuote(args[0]) case strings.EqualFold(name, "WithType"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.DataType = trimQuote(args[0]) + case strings.EqualFold(name, "WithColumnType"), strings.EqualFold(name, "ColumnType"): + if !expectArgs(view, name, args, 2, 2, dql, optionOffset, diags) { + continue + } + columnName := strings.TrimSpace(trimQuote(args[0])) + dataType := strings.TrimSpace(trimQuote(args[1])) + if columnName == "" || dataType == "" { + appendOptionArgDiagnostic(view, name, "column name and type must be non-empty", dql, optionOffset, diags) + continue + } + cfg := ensureDeclaredColumnConfig(view, columnName) + cfg.DataType = dataType + case strings.EqualFold(name, "WithColumnTag"), strings.EqualFold(name, "ColumnTag"): + if !expectArgs(view, name, args, 2, 2, dql, optionOffset, diags) { + continue + } + columnName := strings.TrimSpace(trimQuote(args[0])) + tag := strings.TrimSpace(trimQuote(args[1])) + if columnName == "" || tag == "" { + appendOptionArgDiagnostic(view, name, "column name and tag must be non-empty", dql, optionOffset, diags) + continue + } + cfg := ensureDeclaredColumnConfig(view, columnName) + cfg.Tag = tag + case strings.EqualFold(name, "WithColumnGroupable"), strings.EqualFold(name, "ColumnGroupable"): + if !expectArgs(view, name, args, 2, 2, dql, optionOffset, diags) { + continue + } + columnName := strings.TrimSpace(trimQuote(args[0])) + if columnName == "" { + appendOptionArgDiagnostic(view, name, "column name must be non-empty", dql, optionOffset, diags) + continue + } + groupable, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[1]))) + if err != nil { + appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid bool groupable %q", args[1]), dql, optionOffset, diags) + continue + } + cfg := ensureDeclaredColumnConfig(view, columnName) + cfg.Groupable = &groupable case strings.EqualFold(name, "Of"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.Of = trimQuote(args[0]) case strings.EqualFold(name, "Value"): - if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + if !expectArgs(view, name, args, 1, 1, dql, optionOffset, diags) { continue } view.Value = trimQuote(args[0]) case strings.EqualFold(name, "Async"): - if !expectArgs(view, name, args, 0, 0, dql, offset, diags) { + if !expectArgs(view, name, args, 0, 0, dql, optionOffset, diags) { continue } view.Async = true case strings.EqualFold(name, "Output"): - if !expectArgs(view, name, args, 0, 0, dql, offset, diags) { + if !expectArgs(view, name, args, 0, 0, dql, optionOffset, diags) { continue } view.Output = true + case strings.EqualFold(name, "Required"): + if !expectArgs(view, name, args, 0, 0, dql, optionOffset, diags) { + continue + } + view.Required = true + case strings.EqualFold(name, "Optional"): + if !expectArgs(view, name, args, 0, 0, dql, optionOffset, diags) { + continue + } + view.Required = false + default: + appendOptionArgDiagnostic(view, name, "unknown option", dql, optionOffset, diags) } } } @@ -380,3 +440,15 @@ func isAllowedQuerySelector(name string) bool { return false } } + +func ensureDeclaredColumnConfig(view *declaredView, columnName string) *declaredColumnConfig { + if view.ColumnsConfig == nil { + view.ColumnsConfig = map[string]*declaredColumnConfig{} + } + cfg := view.ColumnsConfig[columnName] + if cfg == nil { + cfg = &declaredColumnConfig{} + view.ColumnsConfig[columnName] = cfg + } + return cfg +} diff --git a/repository/shape/compile/viewdecl_parse.go b/repository/shape/compile/viewdecl_parse.go index 51fd45a9d..dd6469df2 100644 --- a/repository/shape/compile/viewdecl_parse.go +++ b/repository/shape/compile/viewdecl_parse.go @@ -8,8 +8,9 @@ import ( ) type setBlock struct { - Offset int - Body string + Offset int + BodyOffset int + Body string } func extractSetBlocks(dql string) []setBlock { @@ -26,26 +27,29 @@ func extractSetBlocks(dql string) []setBlock { if group.Code != vdExprGroupToken { continue } + groupText := group.Text(cursor) + groupStart := cursor.Pos - len(groupText) body := group.Text(cursor) if len(body) < 2 { continue } result = append(result, setBlock{ - Offset: offset, - Body: body[1 : len(body)-1], + Offset: offset, + BodyOffset: groupStart + 1, + Body: body[1 : len(body)-1], }) } return result } -func parseSetDeclarationBody(body string) (holder, kind, location, tail string, ok bool) { +func parseSetDeclarationBody(body string) (holder, kind, location, tail string, tailOffset int, ok bool) { cursor := parsly.NewCursor("", []byte(body), 0) if cursor.MatchAfterOptional(vdWhitespaceMatcher, vdParamDeclMatcher).Code != vdParamDeclToken { - return "", "", "", "", false + return "", "", "", "", 0, false } id, matched := readIdentifier(cursor) if !matched { - return "", "", "", "", false + return "", "", "", "", 0, false } holder = id _ = cursor.MatchOne(vdWhitespaceMatcher) @@ -53,21 +57,22 @@ func parseSetDeclarationBody(body string) (holder, kind, location, tail string, _ = cursor.MatchOne(vdWhitespaceMatcher) kindLoc := cursor.MatchOne(vdExprGroupMatcher) if kindLoc.Code != vdExprGroupToken { - return "", "", "", "", false + return "", "", "", "", 0, false } inGroup := kindLoc.Text(cursor) if len(inGroup) < 2 { - return "", "", "", "", false + return "", "", "", "", 0, false } raw := strings.TrimSpace(inGroup[1 : len(inGroup)-1]) slash := strings.Index(raw, "/") if slash == -1 { - return "", "", "", "", false + return "", "", "", "", 0, false } kind = strings.ToLower(strings.TrimSpace(raw[:slash])) location = strings.TrimSpace(raw[slash+1:]) + tailOffset = cursor.Pos tail = strings.TrimSpace(string(cursor.Input[cursor.Pos:])) - return holder, kind, location, tail, true + return holder, kind, location, tail, tailOffset, true } func readIdentifier(cursor *parsly.Cursor) (string, bool) { diff --git a/repository/shape/compile/viewdecl_test.go b/repository/shape/compile/viewdecl_test.go index 0136c64a1..f7dff272b 100644 --- a/repository/shape/compile/viewdecl_test.go +++ b/repository/shape/compile/viewdecl_test.go @@ -1,11 +1,13 @@ package compile import ( + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" ) @@ -21,12 +23,13 @@ func TestViewDecl_ExtractSetBlocks(t *testing.T) { } func TestViewDecl_ParseSetDeclarationBody(t *testing.T) { - holder, kind, location, tail, ok := parseSetDeclarationBody("$_ = $Extra(view/extra_view).WithURI('/x')") + holder, kind, location, tail, tailOffset, ok := parseSetDeclarationBody("$_ = $Extra(view/extra_view).WithURI('/x')") require.True(t, ok) assert.Equal(t, "Extra", holder) assert.Equal(t, "view", kind) assert.Equal(t, "extra_view", location) assert.Contains(t, tail, ".WithURI('/x')") + assert.Greater(t, tailOffset, 0) } func TestViewDecl_ApplyOptions_InvalidCardinality(t *testing.T) { @@ -55,13 +58,21 @@ func TestViewDecl_AppendDeclaredViews(t *testing.T) { assert.True(t, found) } +func TestViewDecl_ExtractDeclarationSQLWithLegacyStatusPrefix(t *testing.T) { + sqlText, status := extractDeclarationSQLWithStatus("/* !!403 SELECT id FROM EXTRA e */") + assert.Equal(t, "SELECT id FROM EXTRA e", sqlText) + require.NotNil(t, status) + assert.Equal(t, 403, *status) +} + func TestViewDecl_ApplyOptions_Extended(t *testing.T) { view := &declaredView{Name: "limit"} var diags []*dqlshape.Diagnostic tail := ".WithTag('json:\"id\"').WithCodec(AsJSON,'x').WithHandler('Build',a,b)." + "WithStatusCode(422).WithErrorMessage('bad req').WithPredicate('ByID','id = ?', 101)." + "EnsurePredicate('Tenant','tenant_id = ?', 7).QuerySelector('qs').WithCache('c1').WithLimit(10)." + - "Cacheable(true).When('x > 1').Scope('team').WithType('[]Order').Of('list').Value('abc').Async().Output()" + "Cacheable(true).When('x > 1').Scope('team').Type('OrderView').Dest('orders.go').WithType('[]Order')." + + "WithColumnType('Authorized','bool').WithColumnTag('Authorized','internal:\"true\"').WithColumnGroupable('Authorized', true).Of('list').Value('abc').Async().Output()" applyDeclaredViewOptions(view, tail, "SELECT 1", 0, &diags) require.Empty(t, diags) @@ -89,7 +100,15 @@ func TestViewDecl_ApplyOptions_Extended(t *testing.T) { assert.True(t, *view.Cacheable) assert.Equal(t, "x > 1", view.When) assert.Equal(t, "team", view.Scope) + assert.Equal(t, "OrderView", view.TypeName) + assert.Equal(t, "orders.go", view.Dest) assert.Equal(t, "[]Order", view.DataType) + require.NotNil(t, view.ColumnsConfig) + require.Contains(t, view.ColumnsConfig, "Authorized") + assert.Equal(t, "bool", view.ColumnsConfig["Authorized"].DataType) + assert.Equal(t, `internal:"true"`, view.ColumnsConfig["Authorized"].Tag) + require.NotNil(t, view.ColumnsConfig["Authorized"].Groupable) + assert.True(t, *view.ColumnsConfig["Authorized"].Groupable) assert.Equal(t, "list", view.Of) assert.Equal(t, "abc", view.Value) assert.True(t, view.Async) @@ -104,6 +123,27 @@ func TestViewDecl_ApplyOptions_QuerySelectorValidation(t *testing.T) { assert.Equal(t, dqldiag.CodeDeclQuerySelector, diags[0].Code) } +func TestViewDecl_ApplyOptions_ExactSpanAndUnknownOption(t *testing.T) { + dql := "#set($_ = $Extra(view/extra).UnknownOpt('x').WithLimit('x') /* SELECT id FROM EXTRA e */)" + declared, diags := extractDeclaredViews(dql) + require.NotEmpty(t, declared) + require.Len(t, diags, 2) + assert.Equal(t, dqldiag.CodeDeclOptionArgs, diags[0].Code) + assert.Equal(t, dqldiag.CodeDeclOptionArgs, diags[1].Code) + + unknownOffset := strings.Index(dql, ".UnknownOpt") + require.GreaterOrEqual(t, unknownOffset, 0) + unknownPos := dqlpre.PointSpan(dql, unknownOffset).Start + assert.Equal(t, unknownPos.Line, diags[0].Span.Start.Line) + assert.Equal(t, unknownPos.Char, diags[0].Span.Start.Char) + + limitOffset := strings.Index(dql, ".WithLimit") + require.GreaterOrEqual(t, limitOffset, 0) + limitPos := dqlpre.PointSpan(dql, limitOffset).Start + assert.Equal(t, limitPos.Line, diags[1].Span.Start.Line) + assert.Equal(t, limitPos.Char, diags[1].Span.Start.Char) +} + func TestViewDecl_SplitArgs_Nested(t *testing.T) { args := splitArgs(`'a', fn(1,2), {'k': [1,2]}, "x,y"`) require.Len(t, args, 4) @@ -117,7 +157,8 @@ func TestViewDecl_AppendDeclaredViews_ExtendedDeclarationMetadata(t *testing.T) dql := "#set($_ = $limit(view/limit).WithTag('json:\"id\"').WithCodec(AsJSON).WithHandler('Build',a)." + "WithStatusCode(409).WithErrorMessage('conflict').WithPredicate('ByID','id=?',1)." + "EnsurePredicate('Tenant','tenant=?',2).QuerySelector('items').WithCache('c1').WithLimit(5)." + - "Cacheable(false).When('x').Scope('s').WithType('Order').Of('o').Value('v').Async().Output() /* SELECT id FROM EXTRA e */)" + "Cacheable(false).When('x').Scope('s').Type('OrderView').Dest('order.go').WithType('Order')." + + "WithColumnType('Authorized','bool').WithColumnTag('Authorized','internal:\"true\"').WithColumnGroupable('Authorized', true).Of('o').Value('v').Async().Output() /* SELECT id FROM EXTRA e */)" result := &plan.Result{ ViewsByName: map[string]*plan.View{}, ByPath: map[string]*plan.Field{}, @@ -126,7 +167,7 @@ func TestViewDecl_AppendDeclaredViews_ExtendedDeclarationMetadata(t *testing.T) require.NotEmpty(t, result.Views) var target *plan.View for _, item := range result.Views { - if item != nil && item.Name == "e" { + if item != nil && item.Name == "limit" { target = item break } @@ -147,7 +188,15 @@ func TestViewDecl_AppendDeclaredViews_ExtendedDeclarationMetadata(t *testing.T) assert.False(t, *target.Declaration.Cacheable) assert.Equal(t, "x", target.Declaration.When) assert.Equal(t, "s", target.Declaration.Scope) + assert.Equal(t, "OrderView", target.Declaration.TypeName) + assert.Equal(t, "order.go", target.Declaration.Dest) assert.Equal(t, "Order", target.Declaration.DataType) + require.NotNil(t, target.Declaration.ColumnsConfig) + require.Contains(t, target.Declaration.ColumnsConfig, "Authorized") + assert.Equal(t, "bool", target.Declaration.ColumnsConfig["Authorized"].DataType) + assert.Equal(t, `internal:"true"`, target.Declaration.ColumnsConfig["Authorized"].Tag) + require.NotNil(t, target.Declaration.ColumnsConfig["Authorized"].Groupable) + assert.True(t, *target.Declaration.ColumnsConfig["Authorized"].Groupable) assert.Equal(t, "o", target.Declaration.Of) assert.Equal(t, "v", target.Declaration.Value) assert.True(t, target.Declaration.Async) @@ -172,6 +221,99 @@ func TestViewDecl_AppendDeclaredViews_AttachSummaryFromMetaViewSQL(t *testing.T) assert.Contains(t, root.Summary, "$View.browser.SQL") } +func TestViewDecl_AppendDeclaredViews_AttachSummaryFromOutputSummarySQL(t *testing.T) { + root := &plan.View{Name: "Vendor", Path: "Vendor", Holder: "Vendor"} + result := &plan.Result{ + Views: []*plan.View{root}, + ViewsByName: map[string]*plan.View{"Vendor": root}, + ByPath: map[string]*plan.Field{}, + } + dql := "#define($_ = $Meta(output/summary) /* SELECT COUNT(1) CNT FROM ($View.vendor.SQL) t */)" + + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + require.NotNil(t, root) + assert.Contains(t, root.Summary, "COUNT(1)") + assert.Contains(t, root.Summary, "$View.vendor.SQL") +} + +func TestViewDecl_AppendDeclaredViews_AttachSummaryFromOutputSummaryNonWindowSQL(t *testing.T) { + root := &plan.View{Name: "Vendor", Path: "Vendor", Holder: "Vendor"} + result := &plan.Result{ + Views: []*plan.View{root}, + ViewsByName: map[string]*plan.View{"Vendor": root}, + ByPath: map[string]*plan.Field{}, + } + dql := "#define($_ = $Meta(output/summary) /* SELECT COUNT(1) CNT FROM ($View.NonWindowSQL) t */)" + + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + require.NotNil(t, root) + assert.Contains(t, root.Summary, "COUNT(1)") + assert.Contains(t, root.Summary, "$View.NonWindowSQL") +} + +func TestViewDecl_AppendDeclaredViews_AttachSummaryToReferencedChildView(t *testing.T) { + root := &plan.View{Name: "Vendor", Path: "Vendor", Holder: "Vendor"} + child := &plan.View{Name: "products", Path: "products", Holder: "Products"} + result := &plan.Result{ + Views: []*plan.View{root, child}, + ViewsByName: map[string]*plan.View{ + "Vendor": root, + "products": child, + }, + ByPath: map[string]*plan.Field{}, + } + dql := "#define($_ = $ProductsMeta(view/products_meta) /* SELECT COUNT(1) CNT FROM ($View.products.SQL) t */)" + + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 2) + require.NotNil(t, child) + assert.Equal(t, "ProductsMeta", child.SummaryName) + assert.Contains(t, child.Summary, "COUNT(1)") + assert.Contains(t, child.Summary, "$View.NonWindowSQL") + assert.Empty(t, root.Summary) +} + +func TestViewDecl_AppendDeclaredViews_OutputSummaryWithoutRoot_DoesNotCreateView(t *testing.T) { + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + dql := "#define($_ = $Meta(output/summary) /* SELECT COUNT(1) CNT FROM ($View.NonWindowSQL) t */)" + + appendDeclaredViews(dql, result) + + assert.Empty(t, result.Views) +} + +func TestViewDecl_RequiredImpliesOneCardinalityByDefault(t *testing.T) { + dql := "#define($_ = $Authorization(view/authorization).Required() /* SELECT Authorized FROM AUTH */)" + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + assert.Equal(t, "one", strings.ToLower(result.Views[0].Cardinality)) +} + +func TestViewDecl_ExplicitCardinalityOverridesRequiredDefault(t *testing.T) { + dql := "#define($_ = $Authorization(view/authorization).Required().Cardinality('many') /* SELECT Authorized FROM AUTH */)" + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + assert.Equal(t, "many", strings.ToLower(result.Views[0].Cardinality)) +} + func TestViewDecl_AppendDeclaredViews_MetaViewSQL_NoParentFallbackToView(t *testing.T) { result := &plan.Result{ ViewsByName: map[string]*plan.View{}, diff --git a/repository/shape/componenttag/component.go b/repository/shape/componenttag/component.go new file mode 100644 index 000000000..cd6902f6b --- /dev/null +++ b/repository/shape/componenttag/component.go @@ -0,0 +1,133 @@ +package componenttag + +import ( + "fmt" + "reflect" + "strings" + + tagtags "github.com/viant/tagly/tags" +) + +const TagName = "component" + +type Component struct { + Name string + Path string + Method string + Connector string + Marshaller string + Handler string + Input string + Output string + View string + Source string + Summary string + Report bool + ReportInput string + ReportDimensions string + ReportMeasures string + ReportFilters string + ReportOrderBy string + ReportLimit string + ReportOffset string +} + +type Tag struct { + Component *Component +} + +func (c *Component) Tag() *tagtags.Tag { + if c == nil { + return nil + } + builder := &strings.Builder{} + builder.WriteString(c.Name) + appendNonEmpty(builder, "path", c.Path) + appendNonEmpty(builder, "method", c.Method) + appendNonEmpty(builder, "connector", c.Connector) + appendNonEmpty(builder, "marshaller", c.Marshaller) + appendNonEmpty(builder, "handler", c.Handler) + appendNonEmpty(builder, "input", c.Input) + appendNonEmpty(builder, "output", c.Output) + appendNonEmpty(builder, "view", c.View) + appendNonEmpty(builder, "source", c.Source) + appendNonEmpty(builder, "summary", c.Summary) + if c.Report { + appendNonEmpty(builder, "report", "true") + } + appendNonEmpty(builder, "reportInput", c.ReportInput) + appendNonEmpty(builder, "reportDimensions", c.ReportDimensions) + appendNonEmpty(builder, "reportMeasures", c.ReportMeasures) + appendNonEmpty(builder, "reportFilters", c.ReportFilters) + appendNonEmpty(builder, "reportOrderBy", c.ReportOrderBy) + appendNonEmpty(builder, "reportLimit", c.ReportLimit) + appendNonEmpty(builder, "reportOffset", c.ReportOffset) + return &tagtags.Tag{Name: TagName, Values: tagtags.Values(builder.String())} +} + +func Parse(tag reflect.StructTag) (*Tag, error) { + tagValue, ok := tag.Lookup(TagName) + if !ok { + return &Tag{}, nil + } + name, values := tagtags.Values(tagValue).Name() + component := &Component{Name: name} + if err := values.MatchPairs(func(key, value string) error { + switch strings.ToLower(strings.TrimSpace(key)) { + case "name": + component.Name = strings.TrimSpace(value) + case "path": + component.Path = strings.TrimSpace(value) + case "method": + component.Method = strings.TrimSpace(value) + case "connector": + component.Connector = strings.TrimSpace(value) + case "marshaller": + component.Marshaller = strings.TrimSpace(value) + case "handler": + component.Handler = strings.TrimSpace(value) + case "input": + component.Input = strings.TrimSpace(value) + case "output": + component.Output = strings.TrimSpace(value) + case "view": + component.View = strings.TrimSpace(value) + case "source": + component.Source = strings.TrimSpace(value) + case "summary": + component.Summary = strings.TrimSpace(value) + case "report": + component.Report = strings.EqualFold(strings.TrimSpace(value), "true") + case "reportinput": + component.ReportInput = strings.TrimSpace(value) + case "reportdimensions": + component.ReportDimensions = strings.TrimSpace(value) + case "reportmeasures": + component.ReportMeasures = strings.TrimSpace(value) + case "reportfilters": + component.ReportFilters = strings.TrimSpace(value) + case "reportorderby": + component.ReportOrderBy = strings.TrimSpace(value) + case "reportlimit": + component.ReportLimit = strings.TrimSpace(value) + case "reportoffset": + component.ReportOffset = strings.TrimSpace(value) + default: + return fmt.Errorf("unsupported component tag option: '%s'", key) + } + return nil + }); err != nil { + return nil, err + } + return &Tag{Component: component}, nil +} + +func appendNonEmpty(builder *strings.Builder, key, value string) { + if value == "" { + return + } + builder.WriteString(",") + builder.WriteString(key) + builder.WriteString("=") + builder.WriteString(value) +} diff --git a/repository/shape/dql/decl/calls.go b/repository/shape/dql/decl/calls.go new file mode 100644 index 000000000..ac8016973 --- /dev/null +++ b/repository/shape/dql/decl/calls.go @@ -0,0 +1,97 @@ +package decl + +import ( + "strings" + + "github.com/viant/parsly" +) + +// Call represents a parsed function call with offsets in the scanned input. +type Call struct { + Name string + Args []string + Offset int + EndOffset int + Dollar bool +} + +// CallParseError represents a malformed call span. +type CallParseError struct { + Name string + Offset int + Message string +} + +// CallScanOptions controls call scanning behavior. +type CallScanOptions struct { + AllowedNames map[string]bool + RequireDollar bool + AllowDollar bool + Strict bool +} + +// ScanCalls parses function calls and returns parsed calls plus malformed-call errors. +func ScanCalls(input string, options CallScanOptions) ([]Call, []CallParseError) { + calls := make([]Call, 0) + parseErrors := make([]CallParseError, 0) + cursor := parsly.NewCursor("", []byte(input), 0) + for cursor.Pos < cursor.InputSize { + matched := cursor.MatchAfterOptional( + whitespaceMatcher, + commentBlockMatcher, + singleQuotedMatcher, + doubleQuotedMatcher, + dollarIdentifierMatcher, + identifierMatcher, + anyMatcher, + ) + switch matched.Code { + case dollarIdentifierToken, identifierToken: + rawName := matched.Text(cursor) + hasDollar := matched.Code == dollarIdentifierToken + name := strings.ToLower(strings.TrimPrefix(rawName, "$")) + if options.AllowedNames != nil && !options.AllowedNames[name] { + continue + } + if options.RequireDollar && !hasDollar { + continue + } + if !options.AllowDollar && hasDollar { + continue + } + nameOffset := matched.Offset + block := cursor.MatchAfterOptional(whitespaceMatcher, parenthesesBlockMatcher) + if block.Code != parenthesesBlockToken { + if options.Strict { + parseErrors = append(parseErrors, CallParseError{ + Name: name, + Offset: nameOffset, + Message: "invalid call syntax, expected (...)", + }) + } + continue + } + blockText := block.Text(cursor) + argsText := "" + if len(blockText) >= 2 { + argsText = blockText[1 : len(blockText)-1] + } + calls = append(calls, Call{ + Name: name, + Args: splitArgs(argsText), + Offset: nameOffset, + EndOffset: block.Offset + len(blockText), + Dollar: hasDollar, + }) + case parsly.Invalid: + if options.Strict { + parseErrors = append(parseErrors, CallParseError{ + Offset: cursor.Pos, + Message: "invalid token while scanning calls", + }) + } + cursor.Pos++ + } + } + return calls, parseErrors +} diff --git a/repository/shape/dql/decl/calls_test.go b/repository/shape/dql/decl/calls_test.go new file mode 100644 index 000000000..c467f977f --- /dev/null +++ b/repository/shape/dql/decl/calls_test.go @@ -0,0 +1,52 @@ +package decl + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScanCalls_DollarStrict(t *testing.T) { + input := "$connector('dev') $dest('a.go')" + calls, errs := ScanCalls(input, CallScanOptions{ + AllowedNames: map[string]bool{"connector": true, "dest": true}, + RequireDollar: true, + AllowDollar: true, + Strict: true, + }) + require.Empty(t, errs) + require.Len(t, calls, 2) + assert.Equal(t, "connector", calls[0].Name) + assert.Equal(t, []string{"'dev'"}, calls[0].Args) + assert.True(t, calls[0].Dollar) + assert.Equal(t, "dest", calls[1].Name) +} + +func TestScanCalls_ReportsMalformedCallOffset(t *testing.T) { + input := "$dest('a.go'" + calls, errs := ScanCalls(input, CallScanOptions{ + AllowedNames: map[string]bool{"dest": true}, + RequireDollar: true, + AllowDollar: true, + Strict: true, + }) + require.Empty(t, calls) + require.Len(t, errs, 1) + assert.Equal(t, "dest", errs[0].Name) + assert.Equal(t, 0, errs[0].Offset) +} + +func TestScanCalls_BareOnly(t *testing.T) { + input := "dest(vendor,'vendor.go'), type(vendor,'Vendor'), $dest('x.go')" + calls, errs := ScanCalls(input, CallScanOptions{ + AllowedNames: map[string]bool{"dest": true, "type": true}, + RequireDollar: false, + AllowDollar: false, + Strict: false, + }) + require.Empty(t, errs) + require.Len(t, calls, 2) + assert.Equal(t, "dest", calls[0].Name) + assert.Equal(t, "type", calls[1].Name) +} diff --git a/repository/shape/dql/decl/lex.go b/repository/shape/dql/decl/lex.go index fbf6270ac..7d9e5ed74 100644 --- a/repository/shape/dql/decl/lex.go +++ b/repository/shape/dql/decl/lex.go @@ -11,6 +11,7 @@ const ( doubleQuotedToken commentBlockToken parenthesesBlockToken + dollarIdentifierToken identifierToken anyToken ) @@ -21,6 +22,7 @@ var doubleQuotedMatcher = parsly.NewToken(doubleQuotedToken, "DoubleQuote", matc var commentBlockMatcher = parsly.NewToken(commentBlockToken, "CommentBlock", matcher.NewSeqBlock("/*", "*/")) var parenthesesBlockMatcher = parsly.NewToken(parenthesesBlockToken, "Parentheses", matcher.NewBlock('(', ')', '\\')) +var dollarIdentifierMatcher = parsly.NewToken(dollarIdentifierToken, "DollarIdentifier", &dollarIdentifierMatch{}) var identifierMatcher = parsly.NewToken(identifierToken, "Identifier", &identifierMatch{}) var anyMatcher = parsly.NewToken(anyToken, "Any", &anyMatch{}) @@ -35,6 +37,29 @@ func (a *anyMatch) Match(cursor *parsly.Cursor) int { type identifierMatch struct{} +type dollarIdentifierMatch struct{} + +func (d *dollarIdentifierMatch) Match(cursor *parsly.Cursor) int { + if cursor.Pos >= cursor.InputSize { + return 0 + } + if cursor.Input[cursor.Pos] != '$' { + return 0 + } + next := cursor.Pos + 1 + if next >= cursor.InputSize { + return 0 + } + if !isIdentifierStart(cursor.Input[next]) { + return 0 + } + pos := next + 1 + for pos < cursor.InputSize && isIdentifierPart(cursor.Input[pos]) { + pos++ + } + return pos - cursor.Pos +} + func (i *identifierMatch) Match(cursor *parsly.Cursor) int { if cursor.Pos >= cursor.InputSize { return 0 diff --git a/repository/shape/dql/diag/codes.go b/repository/shape/dql/diag/codes.go index 7fa6a96e9..70a08506e 100644 --- a/repository/shape/dql/diag/codes.go +++ b/repository/shape/dql/diag/codes.go @@ -17,6 +17,13 @@ const ( CodeDirFormat = "DQL-DIR-FORMAT" CodeDirDateFormat = "DQL-DIR-DATE-FORMAT" CodeDirCaseFormat = "DQL-DIR-CASE-FORMAT" + CodeDirConst = "DQL-DIR-CONST" + CodeDirDest = "DQL-DIR-DEST" + CodeDirInputDest = "DQL-DIR-INPUT-DEST" + CodeDirOutputDest = "DQL-DIR-OUTPUT-DEST" + CodeDirRouterDest = "DQL-DIR-ROUTER-DEST" + CodeDirInputType = "DQL-DIR-INPUT-TYPE" + CodeDirOutputType = "DQL-DIR-OUTPUT-TYPE" CodeDirUnsupported = "DQL-DIR-UNSUPPORTED" CodeOptParse = "DQL-OPT-PARSE" diff --git a/repository/shape/dql/parity/adorder_parity_test.go b/repository/shape/dql/parity/adorder_parity_test.go index 667a6c79d..9c7b639b0 100644 --- a/repository/shape/dql/parity/adorder_parity_test.go +++ b/repository/shape/dql/parity/adorder_parity_test.go @@ -9,6 +9,7 @@ import ( dqlplan "github.com/viant/datly/repository/shape/dql/plan" dqlyaml "github.com/viant/datly/repository/shape/dql/render/yaml" dqlscan "github.com/viant/datly/repository/shape/dql/scan" + "github.com/viant/datly/testutil/shapeparity" ) func TestAdorderDQL_CanonicalParityWithYAML(t *testing.T) { @@ -26,12 +27,11 @@ func TestAdorderDQL_CanonicalParityWithYAML(t *testing.T) { t.Skipf("missing fixture yaml file: %v", err) } - scanner := dqlscan.New() connectors := resolveConnectors([]string{ "ci_ads|mysql|root:dev@tcp(127.0.0.1:3307)/ci_ads?parseTime=true&charset=utf8mb4&collation=utf8mb4_bin", "ci_logs|mysql|root:dev@tcp(127.0.0.1:3307)/ci_logs?parseTime=true", }) - scanned, err := scanner.Scan(context.Background(), &dqlscan.Request{ + scanned, err := shapeparity.ScanDQL(context.Background(), &dqlscan.Request{ DQLURL: dqlPath, Repository: repoPath, ModulePrefix: "platform/adorder", diff --git a/repository/shape/dql/parity/connectors.go b/repository/shape/dql/parity/connectors.go index eaaacab8b..c85e00aeb 100644 --- a/repository/shape/dql/parity/connectors.go +++ b/repository/shape/dql/parity/connectors.go @@ -8,6 +8,18 @@ import ( // resolveConnectors returns connectors from env override, or defaults. // When DATLY_PARITY_SQLITE_DSN is set, all default connector names are mapped to sqlite3. +func splitNonEmpty(csv string) []string { + var ret []string + for _, item := range strings.Split(csv, ",") { + item = strings.TrimSpace(item) + if item == "" { + continue + } + ret = append(ret, item) + } + return ret +} + func resolveConnectors(defaults []string) []string { if override := splitNonEmpty(os.Getenv("DATLY_PARITY_CONNECTORS")); len(override) > 0 { return override diff --git a/repository/shape/dql/parity/mdp_parity_test.go b/repository/shape/dql/parity/mdp_parity_test.go index 6941ee914..0311f2bc7 100644 --- a/repository/shape/dql/parity/mdp_parity_test.go +++ b/repository/shape/dql/parity/mdp_parity_test.go @@ -9,6 +9,7 @@ import ( dqlplan "github.com/viant/datly/repository/shape/dql/plan" dqlscan "github.com/viant/datly/repository/shape/dql/scan" + "github.com/viant/datly/testutil/shapeparity" ) func TestMDPDQL_CanonicalParityWithRoutes(t *testing.T) { @@ -41,7 +42,6 @@ func TestMDPDQL_CanonicalParityWithRoutes(t *testing.T) { msg string } var issues []issue - scanner := dqlscan.New() _ = filepath.WalkDir(routesRoot, func(path string, d os.DirEntry, walkErr error) error { if walkErr != nil || d.IsDir() { return walkErr @@ -69,7 +69,7 @@ func TestMDPDQL_CanonicalParityWithRoutes(t *testing.T) { return nil } modulePrefix := filepath.ToSlash(filepath.Join("mdp", ruleDir)) - scanned, err := scanner.Scan(context.Background(), &dqlscan.Request{ + scanned, err := shapeparity.ScanDQL(context.Background(), &dqlscan.Request{ DQLURL: dqlFile, Repository: repoRoot, ModulePrefix: modulePrefix, @@ -146,15 +146,3 @@ func envOr(key, fallback string) string { } return fallback } - -func splitNonEmpty(csv string) []string { - var ret []string - for _, item := range strings.Split(csv, ",") { - item = strings.TrimSpace(item) - if item == "" { - continue - } - ret = append(ret, item) - } - return ret -} diff --git a/repository/shape/dql/preprocess/directive_parser.go b/repository/shape/dql/preprocess/directive_parser.go index 9a13320cc..edc5e98bd 100644 --- a/repository/shape/dql/preprocess/directive_parser.go +++ b/repository/shape/dql/preprocess/directive_parser.go @@ -1,45 +1,54 @@ package preprocess -import "strings" +import ( + "strings" + + "github.com/viant/datly/repository/shape/dql/decl" +) type directiveCall struct { name string args []string start int + end int +} + +type directiveParseError struct { + name string + start int + message string } func scanDollarCalls(input string, names map[string]bool) []directiveCall { + calls, _ := scanDollarCallsStrict(input, names) + return calls +} + +func scanDollarCallsStrict(input string, names map[string]bool) ([]directiveCall, []directiveParseError) { + parsed, parseErrors := decl.ScanCalls(input, decl.CallScanOptions{ + AllowedNames: names, + RequireDollar: true, + AllowDollar: true, + Strict: true, + }) result := make([]directiveCall, 0) - for i := 0; i < len(input); { - if input[i] != '$' || i+1 >= len(input) || !isIdentifierStart(input[i+1]) { - i++ - continue - } - start := i + 1 - i += 2 - for i < len(input) && isIdentifierPart(input[i]) { - i++ - } - name := strings.ToLower(input[start:i]) - if !names[name] { - continue - } - j := skipSpaces(input, i) - if j >= len(input) || input[j] != '(' { - continue - } - body, end, ok := readCallBody(input, j) - if !ok { - continue - } + for _, call := range parsed { result = append(result, directiveCall{ - name: name, - args: splitCallArgs(body), - start: start - 1, + name: call.Name, + args: call.Args, + start: call.Offset, + end: call.EndOffset, + }) + } + errs := make([]directiveParseError, 0, len(parseErrors)) + for _, parseErr := range parseErrors { + errs = append(errs, directiveParseError{ + name: parseErr.Name, + start: parseErr.Offset, + message: parseErr.Message, }) - i = end + 1 } - return result + return result, errs } func readCallBody(input string, openParen int) (string, int, bool) { diff --git a/repository/shape/dql/preprocess/extract.go b/repository/shape/dql/preprocess/extract.go index 67b761a41..e73308041 100644 --- a/repository/shape/dql/preprocess/extract.go +++ b/repository/shape/dql/preprocess/extract.go @@ -18,11 +18,13 @@ func extractSQLAndContext(dql string) (string, *typectx.Context, *dqlshape.Direc blocks := extractSetDirectiveBlocks(dql) for _, block := range blocks { - applyMask(mask, dql, block.start, block.end) + if shouldMaskDirectiveBlock(block) { + applyMask(mask, dql, block.start, block.end) + } if block.kind != directiveSettings { continue } - diagnostics = append(diagnostics, parseSettingsDirectives(block.body, dql, block.start, directives)...) + diagnostics = append(diagnostics, parseSettingsDirectives(block.body, dql, block.bodyStart, directives)...) } lines := strings.SplitAfter(dql, "\n") @@ -43,6 +45,10 @@ func extractSQLAndContext(dql string) (string, *typectx.Context, *dqlshape.Direc } if kind := lineDirectiveKind(trimmed); kind != directiveUnknown { if !hasMasked(mask, lineStart, lineEnd) { + if kind != directiveSettings && !shouldMaskDirectiveLine(kind, trimmed) { + offset += len(line) + continue + } if kind != directiveSettings { applyMask(mask, dql, lineStart, lineEnd) offset += len(line) @@ -72,6 +78,38 @@ func extractSQLAndContext(dql string) (string, *typectx.Context, *dqlshape.Direc return string(masked), ctx, directives, diagnostics } +func shouldMaskDirectiveBlock(block setDirectiveBlock) bool { + switch block.kind { + case directiveSettings, directiveDefine: + return true + case directiveSet: + return isDeclarationDirectiveBody(block.body) + default: + return false + } +} + +func shouldMaskDirectiveLine(kind directiveKind, line string) bool { + switch kind { + case directiveSettings, directiveDefine: + return true + case directiveSet: + start := strings.Index(line, "(") + end := strings.LastIndex(line, ")") + if start == -1 || end <= start { + return false + } + return isDeclarationDirectiveBody(line[start+1 : end]) + default: + return false + } +} + +func isDeclarationDirectiveBody(body string) bool { + text := strings.TrimSpace(body) + return strings.HasPrefix(text, "$_") +} + func applyMask(mask []bool, text string, start, end int) { if start < 0 { start = 0 diff --git a/repository/shape/dql/preprocess/preprocess.go b/repository/shape/dql/preprocess/preprocess.go index 9579dfb0c..d932a1039 100644 --- a/repository/shape/dql/preprocess/preprocess.go +++ b/repository/shape/dql/preprocess/preprocess.go @@ -41,12 +41,30 @@ func Prepare(dql string) *Result { ret.Optimized = optimized sanitized := dqlsanitize.Rewrite(optimized, dqlsanitize.Options{ Declared: dqlsanitize.Declared(optimized), + Foreach: dqlsanitize.ForeachDeclared(optimized), + Consts: constNames(ret.Directives), }) ret.SQL = sanitized.SQL ret.Mapper = newMapper(len(optimized), sanitized.Patches, sanitized.TrimPrefix, dql) return ret } +func constNames(directives *dqlshape.Directives) map[string]bool { + if directives == nil || len(directives.Const) == 0 { + return nil + } + result := make(map[string]bool, len(directives.Const)) + for name := range directives.Const { + if trimmed := strings.TrimSpace(name); trimmed != "" { + result[trimmed] = true + } + } + if len(result) == 0 { + return nil + } + return result +} + func stripDecorators(sql string) string { if strings.TrimSpace(sql) == "" { return sql @@ -72,6 +90,9 @@ func isStandaloneDecoratorLine(line string) bool { if open <= 0 || close <= open { return false } + if strings.TrimSpace(trimmed[close+1:]) != "" { + return false + } name := strings.ToLower(strings.TrimSpace(trimmed[:open])) switch name { case "use_connector", "allow_nulls", "allownulls", "tag", "cast", "required", "cardinality", "set_limit": @@ -115,6 +136,13 @@ func normalizeDirectives(input *dqlshape.Directives) *dqlshape.Directives { ret := &dqlshape.Directives{ Meta: strings.TrimSpace(input.Meta), DefaultConnector: strings.TrimSpace(input.DefaultConnector), + TemplateType: strings.TrimSpace(input.TemplateType), + Dest: strings.TrimSpace(input.Dest), + InputDest: strings.TrimSpace(input.InputDest), + OutputDest: strings.TrimSpace(input.OutputDest), + RouterDest: strings.TrimSpace(input.RouterDest), + InputType: strings.TrimSpace(input.InputType), + OutputType: strings.TrimSpace(input.OutputType), JSONMarshalType: strings.TrimSpace(input.JSONMarshalType), JSONUnmarshalType: strings.TrimSpace(input.JSONUnmarshalType), XMLUnmarshalType: strings.TrimSpace(input.XMLUnmarshalType), @@ -124,8 +152,12 @@ func normalizeDirectives(input *dqlshape.Directives) *dqlshape.Directives { } if input.Cache != nil { ret.Cache = &dqlshape.CacheDirective{ - Enabled: input.Cache.Enabled, - TTL: strings.TrimSpace(input.Cache.TTL), + Enabled: input.Cache.Enabled, + TTL: strings.TrimSpace(input.Cache.TTL), + Name: strings.TrimSpace(input.Cache.Name), + Provider: strings.TrimSpace(input.Cache.Provider), + Location: strings.TrimSpace(input.Cache.Location), + TimeToLiveMs: input.Cache.TimeToLiveMs, } } if input.MCP != nil { @@ -147,9 +179,30 @@ func normalizeDirectives(input *dqlshape.Directives) *dqlshape.Directives { Methods: normalizedMethods, } } - if ret.Meta == "" && ret.DefaultConnector == "" && ret.Cache == nil && ret.MCP == nil && ret.Route == nil && + if input.Report != nil { + ret.Report = &dqlshape.ReportDirective{ + Enabled: input.Report.Enabled, + Input: strings.TrimSpace(input.Report.Input), + Dimensions: strings.TrimSpace(input.Report.Dimensions), + Measures: strings.TrimSpace(input.Report.Measures), + Filters: strings.TrimSpace(input.Report.Filters), + OrderBy: strings.TrimSpace(input.Report.OrderBy), + Limit: strings.TrimSpace(input.Report.Limit), + Offset: strings.TrimSpace(input.Report.Offset), + } + } + if len(input.Const) > 0 { + ret.Const = make(map[string]string, len(input.Const)) + for k, v := range input.Const { + ret.Const[k] = v + } + } + if ret.Meta == "" && ret.DefaultConnector == "" && ret.TemplateType == "" && + ret.Dest == "" && ret.InputDest == "" && ret.OutputDest == "" && ret.RouterDest == "" && + ret.InputType == "" && ret.OutputType == "" && + ret.Cache == nil && ret.MCP == nil && ret.Route == nil && ret.Report == nil && ret.JSONMarshalType == "" && ret.JSONUnmarshalType == "" && ret.XMLUnmarshalType == "" && ret.Format == "" && - ret.DateFormat == "" && ret.CaseFormat == "" { + ret.DateFormat == "" && ret.CaseFormat == "" && len(ret.Const) == 0 { return nil } return ret diff --git a/repository/shape/dql/preprocess/preprocess_test.go b/repository/shape/dql/preprocess/preprocess_test.go index 337b9ca41..29e06bcb4 100644 --- a/repository/shape/dql/preprocess/preprocess_test.go +++ b/repository/shape/dql/preprocess/preprocess_test.go @@ -1,6 +1,7 @@ package preprocess import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -53,6 +54,58 @@ FROM t` assert.NotContains(t, pre.DirectSQL, ",\nFROM") } +func TestPrepare_PreservesSQLCastProjection(t *testing.T) { + dql := `SELECT + CAST($Var3 AS SIGNED) AS Key3, + cast(status, 'int') +FROM t` + pre := Prepare(dql) + require.NotNil(t, pre) + assert.Contains(t, pre.DirectSQL, "CAST($Var3 AS SIGNED) AS Key3") + assert.NotContains(t, pre.DirectSQL, "cast(status, 'int')") +} + +func TestPrepare_PreservesExecControlDirectives(t *testing.T) { + dql := "#define($_ = $Ids<[]int>(body/Ids))\n" + + "#foreach($rec in $Unsafe.Records)\n" + + "#if($rec.IS_AUTH == 0)\n" + + "$logger.Fatal('x')\n" + + "#else\n" + + "UPDATE PRODUCT SET STATUS = $Status WHERE ID = $rec.ID;\n" + + "#end\n" + + "#end" + pre := Prepare(dql) + require.NotNil(t, pre) + assert.Contains(t, pre.SQL, "#foreach($rec in $Unsafe.Records)") + assert.Contains(t, pre.SQL, "#if($rec.IS_AUTH == 0)") + assert.Contains(t, pre.SQL, "#else") + assert.Contains(t, pre.SQL, "#end") +} + +func TestPrepare_PreservesLocalSetDirectivesInExecTemplate(t *testing.T) { + dql := "#define($_ = $Ids<[]int>(query/Ids))\n" + + "#set($byID = $Unsafe.Rows.IndexBy(\"ID\"))\n" + + "#foreach($id in $Unsafe.Ids)\n" + + " #set($row = $byID[$id])\n" + + " UPDATE T SET ACTIVE = 0 WHERE ID = $id;\n" + + "#end" + pre := Prepare(dql) + require.NotNil(t, pre) + assert.Contains(t, pre.SQL, "#set($byID = $Unsafe.Rows.IndexBy(\"ID\"))") + assert.Contains(t, pre.SQL, "#set($row = $byID[$id])") + assert.NotContains(t, pre.SQL, "#define($_ = $Ids<[]int>(query/Ids))") +} + +func TestPrepare_ConstDirective_UsesUnsafeSelectors(t *testing.T) { + dql := "#setting($_ = $const('Vendor','VENDOR'))\n" + + "SELECT * FROM ${Vendor} t WHERE t.ID = $id" + pre := Prepare(dql) + require.NotNil(t, pre) + assert.Contains(t, pre.SQL, "FROM ${Unsafe.Vendor} t") + assert.Contains(t, pre.SQL, "$criteria.AppendBinding($Unsafe.id)") + assert.NotContains(t, pre.SQL, "${criteria.AppendBinding($Unsafe.Vendor)}") +} + func TestPrepare_MultilineSetDirective_TypeContext(t *testing.T) { dql := "#package('a/b')\n#import('x','github.com/acme/x')\nSELECT id FROM t" pre := Prepare(dql) @@ -78,6 +131,13 @@ func TestPrepare_InvalidMultilineImportDiagnostic(t *testing.T) { func TestPrepare_SpecialDirectives(t *testing.T) { dql := "#settings($_ = $meta('docs/orders.md'))\n" + "#setting($_ = $connector('analytics'))\n" + + "#setting($_ = $report('OrderReportInput','Dims','Metrics','Predicates','Sort','Take','Skip'))\n" + + "#setting($_ = $dest('vendor.go'))\n" + + "#setting($_ = $input_dest('vendor_input.go'))\n" + + "#setting($_ = $output_dest('vendor_output.go'))\n" + + "#setting($_ = $router_dest('vendor_router.go'))\n" + + "#setting($_ = $input_type('VendorInput'))\n" + + "#setting($_ = $output_type('VendorOutput'))\n" + "#settings($_ = $cache(true, '5m'))\n" + "#settings($_ = $mcp('orders.search', 'Search orders', 'docs/mcp/orders.md'))\n" + "#settings($_ = $marshal('application/json','pkg.OrderJSON'))\n" + @@ -86,12 +146,28 @@ func TestPrepare_SpecialDirectives(t *testing.T) { "#settings($_ = $format('tabular_json'))\n" + "#settings($_ = $date_format('2006-01-02'))\n" + "#settings($_ = $case_format('lc'))\n" + + "#settings($_ = $useTemplate('patch'))\n" + "SELECT id FROM ORDERS o" pre := Prepare(dql) require.NotNil(t, pre) require.NotNil(t, pre.Directives) assert.Equal(t, "docs/orders.md", pre.Directives.Meta) assert.Equal(t, "analytics", pre.Directives.DefaultConnector) + require.NotNil(t, pre.Directives.Report) + assert.True(t, pre.Directives.Report.Enabled) + assert.Equal(t, "OrderReportInput", pre.Directives.Report.Input) + assert.Equal(t, "Dims", pre.Directives.Report.Dimensions) + assert.Equal(t, "Metrics", pre.Directives.Report.Measures) + assert.Equal(t, "Predicates", pre.Directives.Report.Filters) + assert.Equal(t, "Sort", pre.Directives.Report.OrderBy) + assert.Equal(t, "Take", pre.Directives.Report.Limit) + assert.Equal(t, "Skip", pre.Directives.Report.Offset) + assert.Equal(t, "vendor.go", pre.Directives.Dest) + assert.Equal(t, "vendor_input.go", pre.Directives.InputDest) + assert.Equal(t, "vendor_output.go", pre.Directives.OutputDest) + assert.Equal(t, "vendor_router.go", pre.Directives.RouterDest) + assert.Equal(t, "VendorInput", pre.Directives.InputType) + assert.Equal(t, "VendorOutput", pre.Directives.OutputType) require.NotNil(t, pre.Directives.Cache) assert.True(t, pre.Directives.Cache.Enabled) assert.Equal(t, "5m", pre.Directives.Cache.TTL) @@ -105,6 +181,28 @@ func TestPrepare_SpecialDirectives(t *testing.T) { assert.Equal(t, "tabular", pre.Directives.Format) assert.Equal(t, "2006-01-02", pre.Directives.DateFormat) assert.Equal(t, "lc", pre.Directives.CaseFormat) + assert.Equal(t, "patch", pre.Directives.TemplateType) +} + +func TestPrepare_InvalidDestDirectiveDiagnostic(t *testing.T) { + dql := "SELECT 1\n#settings($_ = $dest())" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirDest, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) +} + +func TestPrepare_CacheProviderDirective(t *testing.T) { + dql := "#setting($_ = $cache('aerospike').WithProvider('aerospike://127.0.0.1:3000/test').WithLocation('${view.Name}').WithTimeToLiveMs(3600000))\nSELECT 1" + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotNil(t, pre.Directives) + require.NotNil(t, pre.Directives.Cache) + assert.Equal(t, "aerospike", pre.Directives.Cache.Name) + assert.Equal(t, "aerospike://127.0.0.1:3000/test", pre.Directives.Cache.Provider) + assert.Equal(t, "${view.Name}", pre.Directives.Cache.Location) + assert.Equal(t, 3600000, pre.Directives.Cache.TimeToLiveMs) } func TestPrepare_InvalidSpecialDirectiveDiagnostic(t *testing.T) { @@ -125,6 +223,28 @@ func TestPrepare_InvalidConnectorDirectiveDiagnostic(t *testing.T) { assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) } +func TestPrepare_InvalidDirective_UsesExactCallSpan(t *testing.T) { + lineText := "#settings($_ = $dest())" + dql := "SELECT 1\n" + lineText + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirDest, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) + assert.Equal(t, strings.Index(lineText, "$dest(")+1, pre.Diagnostics[0].Span.Start.Char) +} + +func TestPrepare_MalformedDirective_UsesExactCallSpan(t *testing.T) { + lineText := "#settings($_ = $dest('x'" + dql := "SELECT 1\n" + lineText + pre := Prepare(dql) + require.NotNil(t, pre) + require.NotEmpty(t, pre.Diagnostics) + assert.Equal(t, dqldiag.CodeDirDest, pre.Diagnostics[0].Code) + assert.Equal(t, 2, pre.Diagnostics[0].Span.Start.Line) + assert.Equal(t, strings.Index(lineText, "$dest(")+1, pre.Diagnostics[0].Span.Start.Char) +} + func TestPrepare_RouteDirective(t *testing.T) { dql := "SELECT 1\n#settings($_ = $route('/v1/api/orders', 'GET', 'POST', 'PATCH'))" pre := Prepare(dql) diff --git a/repository/shape/dql/preprocess/scanner.go b/repository/shape/dql/preprocess/scanner.go index 3b7513917..f60031149 100644 --- a/repository/shape/dql/preprocess/scanner.go +++ b/repository/shape/dql/preprocess/scanner.go @@ -16,10 +16,11 @@ var ( ) type setDirectiveBlock struct { - start int - end int - body string - kind directiveKind + start int + end int + bodyStart int + body string + kind directiveKind } type directiveKind int @@ -41,9 +42,6 @@ func isDirectiveLine(line string) bool { if isSetLine(line) { return true } - if strings.HasPrefix(line, "#if(") || strings.HasPrefix(line, "#elseif(") || strings.HasPrefix(line, "#else") || strings.HasPrefix(line, "#end") { - return true - } return false } @@ -76,10 +74,11 @@ func extractSetDirectiveBlocks(dql string) []setDirectiveBlock { } end := cursor.Pos result = append(result, setDirectiveBlock{ - start: start, - end: end, - body: groupText[1 : len(groupText)-1], - kind: kind, + start: start, + end: end, + bodyStart: group.Offset + 1, + body: groupText[1 : len(groupText)-1], + kind: kind, }) } return result diff --git a/repository/shape/dql/preprocess/settings_directives.go b/repository/shape/dql/preprocess/settings_directives.go index 3d7793c8f..4327a7780 100644 --- a/repository/shape/dql/preprocess/settings_directives.go +++ b/repository/shape/dql/preprocess/settings_directives.go @@ -2,6 +2,8 @@ package preprocess import ( "net/http" + "regexp" + "strconv" "strings" "github.com/viant/datly/repository/content" @@ -11,16 +13,28 @@ import ( ) var ( - metaDirectiveName = map[string]bool{"meta": true} - connectorDirectiveName = map[string]bool{"connector": true} - cacheDirectiveName = map[string]bool{"cache": true} - mcpDirectiveName = map[string]bool{"mcp": true} - routeDirectiveName = map[string]bool{"route": true} - marshalDirectiveName = map[string]bool{"marshal": true} - unmarshalDirectiveName = map[string]bool{"unmarshal": true} - formatDirectiveName = map[string]bool{"format": true} - dateFormatDirectiveName = map[string]bool{"date_format": true} - caseFormatDirectiveName = map[string]bool{"case_format": true} + metaDirectiveName = map[string]bool{"meta": true} + connectorDirectiveName = map[string]bool{"connector": true} + cacheDirectiveName = map[string]bool{"cache": true} + mcpDirectiveName = map[string]bool{"mcp": true} + routeDirectiveName = map[string]bool{"route": true} + reportDirectiveName = map[string]bool{"report": true} + constDirectiveName = map[string]bool{"const": true} + marshalDirectiveName = map[string]bool{"marshal": true} + unmarshalDirectiveName = map[string]bool{"unmarshal": true} + formatDirectiveName = map[string]bool{"format": true} + dateFormatDirectiveName = map[string]bool{"date_format": true} + caseFormatDirectiveName = map[string]bool{"case_format": true} + useTemplateDirectiveName = map[string]bool{"usetemplate": true} + destDirectiveName = map[string]bool{"dest": true} + inputDestDirectiveName = map[string]bool{"input_dest": true} + outputDestDirectiveName = map[string]bool{"output_dest": true} + routerDestDirectiveName = map[string]bool{"router_dest": true} + inputTypeDirectiveName = map[string]bool{"input_type": true} + outputTypeDirectiveName = map[string]bool{"output_type": true} + cacheProviderExpr = regexp.MustCompile(`(?i)\.withprovider\s*\(\s*['"]([^'"]+)['"]\s*\)`) + cacheLocationExpr = regexp.MustCompile(`(?i)\.withlocation\s*\(\s*['"]([^'"]+)['"]\s*\)`) + cacheTTLMsExpr = regexp.MustCompile(`(?i)\.withtimetolivems\s*\(\s*([0-9]+)\s*\)`) ) func parseSettingsDirectives(input, fullDQL string, diagnosticOffset int, directives *dqlshape.Directives) []*dqlshape.Diagnostic { @@ -39,57 +53,114 @@ func parseSettingsDirectives(input, fullDQL string, diagnosticOffset int, direct )) } if strings.Contains(lower, "$meta") { - values := parseMetaDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, metaDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirMeta, fullDQL, diagnosticOffset) + values := parseMetaDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMeta, "invalid $meta directive", "expected: #settings($_ = $meta('relative/or/absolute/path'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMeta, "invalid $meta directive", "expected: #settings($_ = $meta('relative/or/absolute/path'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.Meta = values[len(values)-1] } } if strings.Contains(lower, "$connector") { - values := parseConnectorDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, connectorDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirConnector, fullDQL, diagnosticOffset) + values := parseConnectorDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirConnector, "invalid $connector directive", "expected: #settings($_ = $connector('connector_name'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirConnector, "invalid $connector directive", "expected: #settings($_ = $connector('connector_name'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.DefaultConnector = values[len(values)-1] } } if strings.Contains(lower, "$cache") { - values := parseCacheDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, cacheDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirCache, fullDQL, diagnosticOffset) + values := parseCacheDirectiveCalls(input, calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirCache, "invalid $cache directive", "expected: #settings($_ = $cache(true, '5m'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirCache, "invalid $cache directive", "expected: #settings($_ = $cache(true, '5m')) or #setting($_ = $cache('name').WithProvider('...').WithLocation('...').WithTimeToLiveMs(1000))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.Cache = values[len(values)-1] } } if strings.Contains(lower, "$mcp") { - values := parseMCPDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, mcpDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirMCP, fullDQL, diagnosticOffset) + values := parseMCPDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMCP, "invalid $mcp directive", "expected: #settings($_ = $mcp('tool.name','description','docs/path.md'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMCP, "invalid $mcp directive", "expected: #settings($_ = $mcp('tool.name','description','docs/path.md'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.MCP = values[len(values)-1] } } if strings.Contains(lower, "$route") { - values := parseRouteDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, routeDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirRoute, fullDQL, diagnosticOffset) + values := parseRouteDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirRoute, "invalid $route directive", "expected: #settings($_ = $route('/v1/api/path','GET','POST'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirRoute, "invalid $route directive", "expected: #settings($_ = $route('/v1/api/path','GET','POST'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.Route = values[len(values)-1] } } + if strings.Contains(lower, "$report") { + calls, parseErrors := scanDollarCallsStrict(input, reportDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirRoute, fullDQL, diagnosticOffset) + values := parseReportDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirRoute, "invalid $report directive", "expected: #settings($_ = $report()) or #settings($_ = $report('InputType','Dimensions','Measures','Filters','OrderBy','Limit','Offset'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.Report = values[len(values)-1] + } + } + if strings.Contains(lower, "$const") { + calls, parseErrors := scanDollarCallsStrict(input, constDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirConst, fullDQL, diagnosticOffset) + values := parseConstDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirConst, "invalid $const directive", "expected: #settings($_ = $const('Name','VALUE'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + if directives.Const == nil { + directives.Const = map[string]string{} + } + for _, kv := range values { + directives.Const[kv[0]] = kv[1] + } + } + } if strings.Contains(lower, "$marshal") { - values := parseMarshalDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, marshalDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirMarshal, fullDQL, diagnosticOffset) + values := parseMarshalDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMarshal, "invalid $marshal directive", "expected: #settings($_ = $marshal('application/json','pkg.Type'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirMarshal, "invalid $marshal directive", "expected: #settings($_ = $marshal('application/json','pkg.Type'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.JSONMarshalType = values[len(values)-1] } } if strings.Contains(lower, "$unmarshal") { - values := parseUnmarshalDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, unmarshalDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirUnmarshal, fullDQL, diagnosticOffset) + values := parseUnmarshalDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirUnmarshal, "invalid $unmarshal directive", "expected: #settings($_ = $unmarshal('application/json','pkg.Type'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirUnmarshal, "invalid $unmarshal directive", "expected: #settings($_ = $unmarshal('application/json','pkg.Type'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { last := values[len(values)-1] if last.JSONType != "" { @@ -101,34 +172,199 @@ func parseSettingsDirectives(input, fullDQL string, diagnosticOffset int, direct } } if strings.Contains(lower, "$format") { - values := parseFormatDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, formatDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirFormat, fullDQL, diagnosticOffset) + values := parseFormatDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirFormat, "invalid $format directive", "expected: #settings($_ = $format('tabular_json'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirFormat, "invalid $format directive", "expected: #settings($_ = $format('tabular_json'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.Format = values[len(values)-1] } } if strings.Contains(lower, "$date_format") { - values := parseDateFormatDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, dateFormatDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirDateFormat, fullDQL, diagnosticOffset) + values := parseDateFormatDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirDateFormat, "invalid $date_format directive", "expected: #settings($_ = $date_format('2006-01-02'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirDateFormat, "invalid $date_format directive", "expected: #settings($_ = $date_format('2006-01-02'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.DateFormat = values[len(values)-1] } } if strings.Contains(lower, "$case_format") { - values := parseCaseFormatDirectives(input) + calls, parseErrors := scanDollarCallsStrict(input, caseFormatDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirCaseFormat, fullDQL, diagnosticOffset) + values := parseCaseFormatDirectiveCalls(calls) if len(values) == 0 { - diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirCaseFormat, "invalid $case_format directive", "expected: #settings($_ = $case_format('lc'))", fullDQL, diagnosticOffset)) + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirCaseFormat, "invalid $case_format directive", "expected: #settings($_ = $case_format('lc'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } } else { directives.CaseFormat = values[len(values)-1] } } + if strings.Contains(lower, "$usetemplate") { + calls, parseErrors := scanDollarCallsStrict(input, useTemplateDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirUnsupported, fullDQL, diagnosticOffset) + values := parseSingleArgQuotedDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirUnsupported, "invalid $useTemplate directive", "expected: #settings($_ = $useTemplate('patch'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.TemplateType = values[len(values)-1] + } + } + if strings.Contains(lower, "$dest") { + calls, parseErrors := scanDollarCallsStrict(input, destDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirDest, fullDQL, diagnosticOffset) + values := parseSingleArgQuotedDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirDest, "invalid $dest directive", "expected: #settings($_ = $dest('file.go'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.Dest = values[len(values)-1] + } + } + if strings.Contains(lower, "$input_dest") { + calls, parseErrors := scanDollarCallsStrict(input, inputDestDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirInputDest, fullDQL, diagnosticOffset) + values := parseSingleArgQuotedDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirInputDest, "invalid $input_dest directive", "expected: #settings($_ = $input_dest('input.go'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.InputDest = values[len(values)-1] + } + } + if strings.Contains(lower, "$output_dest") { + calls, parseErrors := scanDollarCallsStrict(input, outputDestDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirOutputDest, fullDQL, diagnosticOffset) + values := parseSingleArgQuotedDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirOutputDest, "invalid $output_dest directive", "expected: #settings($_ = $output_dest('output.go'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.OutputDest = values[len(values)-1] + } + } + if strings.Contains(lower, "$router_dest") { + calls, parseErrors := scanDollarCallsStrict(input, routerDestDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirRouterDest, fullDQL, diagnosticOffset) + values := parseSingleArgQuotedDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirRouterDest, "invalid $router_dest directive", "expected: #settings($_ = $router_dest('router.go'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.RouterDest = values[len(values)-1] + } + } + if strings.Contains(lower, "$input_type") { + calls, parseErrors := scanDollarCallsStrict(input, inputTypeDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirInputType, fullDQL, diagnosticOffset) + values := parseSingleArgQuotedDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirInputType, "invalid $input_type directive", "expected: #settings($_ = $input_type('TypeName'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.InputType = values[len(values)-1] + } + } + if strings.Contains(lower, "$output_type") { + calls, parseErrors := scanDollarCallsStrict(input, outputTypeDirectiveName) + diagnostics = appendDirectiveParseErrors(diagnostics, parseErrors, dqldiag.CodeDirOutputType, fullDQL, diagnosticOffset) + values := parseSingleArgQuotedDirectiveCalls(calls) + if len(values) == 0 { + if len(calls) > 0 { + diagnostics = append(diagnostics, directiveDiagnostic(dqldiag.CodeDirOutputType, "invalid $output_type directive", "expected: #settings($_ = $output_type('TypeName'))", fullDQL, lastDirectiveCallOffset(calls, diagnosticOffset))) + } + } else { + directives.OutputType = values[len(values)-1] + } + } return diagnostics } +func appendDirectiveParseErrors(diagnostics []*dqlshape.Diagnostic, parseErrors []directiveParseError, code, fullDQL string, diagnosticOffset int) []*dqlshape.Diagnostic { + for _, parseErr := range parseErrors { + message := "invalid directive syntax" + if parseErr.name != "" { + message = "invalid $" + parseErr.name + " directive" + } + diagnostics = append(diagnostics, directiveDiagnostic(code, message, "fix malformed directive call syntax", fullDQL, diagnosticOffset+parseErr.start)) + } + return diagnostics +} + +func lastDirectiveCallOffset(calls []directiveCall, diagnosticOffset int) int { + if len(calls) == 0 { + return diagnosticOffset + } + return diagnosticOffset + calls[len(calls)-1].start +} + +func parseDestDirectives(input string) []string { + calls := scanDollarCalls(input, destDirectiveName) + return parseSingleArgQuotedDirectiveCalls(calls) +} + +func parseInputDestDirectives(input string) []string { + calls := scanDollarCalls(input, inputDestDirectiveName) + return parseSingleArgQuotedDirectiveCalls(calls) +} + +func parseOutputDestDirectives(input string) []string { + calls := scanDollarCalls(input, outputDestDirectiveName) + return parseSingleArgQuotedDirectiveCalls(calls) +} + +func parseRouterDestDirectives(input string) []string { + calls := scanDollarCalls(input, routerDestDirectiveName) + return parseSingleArgQuotedDirectiveCalls(calls) +} + +func parseInputTypeDirectives(input string) []string { + calls := scanDollarCalls(input, inputTypeDirectiveName) + return parseSingleArgQuotedDirectiveCalls(calls) +} + +func parseOutputTypeDirectives(input string) []string { + calls := scanDollarCalls(input, outputTypeDirectiveName) + return parseSingleArgQuotedDirectiveCalls(calls) +} + +func parseSingleArgQuotedDirectiveCalls(calls []directiveCall) []string { + result := make([]string, 0, len(calls)) + for _, call := range calls { + if len(call.args) != 1 { + continue + } + value, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + if value = strings.TrimSpace(value); value != "" { + result = append(result, value) + } + } + return result +} + func parseMetaDirectives(input string) []string { calls := scanDollarCalls(input, metaDirectiveName) + return parseMetaDirectiveCalls(calls) +} + +func parseMetaDirectiveCalls(calls []directiveCall) []string { result := make([]string, 0, len(calls)) for _, call := range calls { if len(call.args) != 1 { @@ -147,6 +383,10 @@ func parseMetaDirectives(input string) []string { func parseConnectorDirectives(input string) []string { calls := scanDollarCalls(input, connectorDirectiveName) + return parseConnectorDirectiveCalls(calls) +} + +func parseConnectorDirectiveCalls(calls []directiveCall) []string { result := make([]string, 0, len(calls)) for _, call := range calls { if len(call.args) != 1 { @@ -165,36 +405,81 @@ func parseConnectorDirectives(input string) []string { func parseCacheDirectives(input string) []*dqlshape.CacheDirective { calls := scanDollarCalls(input, cacheDirectiveName) + return parseCacheDirectiveCalls(input, calls) +} + +func parseCacheDirectiveCalls(input string, calls []directiveCall) []*dqlshape.CacheDirective { result := make([]*dqlshape.CacheDirective, 0, len(calls)) for _, call := range calls { if len(call.args) == 0 || len(call.args) > 2 { continue } - enabledRaw := strings.TrimSpace(call.args[0]) - var enabled bool - switch { - case strings.EqualFold(enabledRaw, "true"): - enabled = true - case strings.EqualFold(enabledRaw, "false"): - enabled = false - default: + firstArg := strings.TrimSpace(call.args[0]) + if strings.EqualFold(firstArg, "true") || strings.EqualFold(firstArg, "false") { + ttl := "" + if len(call.args) == 2 { + value, ok := parseQuotedLiteral(call.args[1]) + if !ok { + continue + } + ttl = strings.TrimSpace(value) + } + result = append(result, &dqlshape.CacheDirective{ + Enabled: strings.EqualFold(firstArg, "true"), + TTL: ttl, + }) + continue + } + name, ok := parseQuotedLiteral(firstArg) + if !ok { + continue + } + name = strings.TrimSpace(name) + if name == "" { + continue + } + tail := "" + if call.end > 0 && call.end <= len(input) { + tail = input[call.end:] + } + cacheDirective := &dqlshape.CacheDirective{ + Enabled: true, + Name: name, + } + if match := cacheProviderExpr.FindStringSubmatch(tail); len(match) > 1 { + cacheDirective.Provider = strings.TrimSpace(match[1]) + } + if match := cacheLocationExpr.FindStringSubmatch(tail); len(match) > 1 { + cacheDirective.Location = strings.TrimSpace(match[1]) + } + if match := cacheTTLMsExpr.FindStringSubmatch(tail); len(match) > 1 { + if ttlMs, err := strconv.Atoi(strings.TrimSpace(match[1])); err == nil && ttlMs > 0 { + cacheDirective.TimeToLiveMs = ttlMs + } + } + if cacheDirective.Provider == "" || cacheDirective.Location == "" || cacheDirective.TimeToLiveMs <= 0 { continue } - ttl := "" if len(call.args) == 2 { value, ok := parseQuotedLiteral(call.args[1]) - if !ok { - continue + if ok { + cacheDirective.TTL = strings.TrimSpace(value) } - ttl = strings.TrimSpace(value) } - result = append(result, &dqlshape.CacheDirective{Enabled: enabled, TTL: ttl}) + if cacheDirective.TTL == "" { + cacheDirective.TTL = strconv.Itoa(cacheDirective.TimeToLiveMs) + "ms" + } + result = append(result, cacheDirective) } return result } func parseMCPDirectives(input string) []*dqlshape.MCPDirective { calls := scanDollarCalls(input, mcpDirectiveName) + return parseMCPDirectiveCalls(calls) +} + +func parseMCPDirectiveCalls(calls []directiveCall) []*dqlshape.MCPDirective { result := make([]*dqlshape.MCPDirective, 0, len(calls)) for _, call := range calls { if len(call.args) < 1 || len(call.args) > 3 { @@ -235,6 +520,10 @@ func parseMCPDirectives(input string) []*dqlshape.MCPDirective { func parseRouteDirectives(input string) []*dqlshape.RouteDirective { calls := scanDollarCalls(input, routeDirectiveName) + return parseRouteDirectiveCalls(calls) +} + +func parseRouteDirectiveCalls(calls []directiveCall) []*dqlshape.RouteDirective { result := make([]*dqlshape.RouteDirective, 0, len(calls)) for _, call := range calls { if len(call.args) == 0 { @@ -272,6 +561,49 @@ func parseRouteDirectives(input string) []*dqlshape.RouteDirective { return result } +func parseReportDirectiveCalls(calls []directiveCall) []*dqlshape.ReportDirective { + result := make([]*dqlshape.ReportDirective, 0, len(calls)) + for _, call := range calls { + args := make([]string, 0, len(call.args)) + valid := true + for _, raw := range call.args { + value, ok := parseQuotedLiteral(raw) + if !ok { + valid = false + break + } + args = append(args, strings.TrimSpace(value)) + } + if !valid { + continue + } + directive := &dqlshape.ReportDirective{Enabled: true} + if len(args) > 0 { + directive.Input = strings.TrimSpace(args[0]) + } + if len(args) > 1 { + directive.Dimensions = strings.TrimSpace(args[1]) + } + if len(args) > 2 { + directive.Measures = strings.TrimSpace(args[2]) + } + if len(args) > 3 { + directive.Filters = strings.TrimSpace(args[3]) + } + if len(args) > 4 { + directive.OrderBy = strings.TrimSpace(args[4]) + } + if len(args) > 5 { + directive.Limit = strings.TrimSpace(args[5]) + } + if len(args) > 6 { + directive.Offset = strings.TrimSpace(args[6]) + } + result = append(result, directive) + } + return result +} + func normalizeHTTPMethods(input []string) ([]string, bool) { if len(input) == 0 { return nil, true @@ -308,6 +640,10 @@ func normalizeHTTPMethods(input []string) ([]string, bool) { func parseMarshalDirectives(input string) []string { calls := scanDollarCalls(input, marshalDirectiveName) + return parseMarshalDirectiveCalls(calls) +} + +func parseMarshalDirectiveCalls(calls []directiveCall) []string { result := make([]string, 0, len(calls)) for _, call := range calls { if len(call.args) != 2 { @@ -339,6 +675,10 @@ type unmarshalDirectiveValue struct { func parseUnmarshalDirectives(input string) []unmarshalDirectiveValue { calls := scanDollarCalls(input, unmarshalDirectiveName) + return parseUnmarshalDirectiveCalls(calls) +} + +func parseUnmarshalDirectiveCalls(calls []directiveCall) []unmarshalDirectiveValue { result := make([]unmarshalDirectiveValue, 0, len(calls)) for _, call := range calls { if len(call.args) != 2 { @@ -373,6 +713,10 @@ func parseUnmarshalDirectives(input string) []unmarshalDirectiveValue { func parseFormatDirectives(input string) []string { calls := scanDollarCalls(input, formatDirectiveName) + return parseFormatDirectiveCalls(calls) +} + +func parseFormatDirectiveCalls(calls []directiveCall) []string { result := make([]string, 0, len(calls)) for _, call := range calls { if len(call.args) != 1 { @@ -395,6 +739,10 @@ func parseFormatDirectives(input string) []string { func parseDateFormatDirectives(input string) []string { calls := scanDollarCalls(input, dateFormatDirectiveName) + return parseDateFormatDirectiveCalls(calls) +} + +func parseDateFormatDirectiveCalls(calls []directiveCall) []string { result := make([]string, 0, len(calls)) for _, call := range calls { if len(call.args) != 1 { @@ -413,6 +761,10 @@ func parseDateFormatDirectives(input string) []string { func parseCaseFormatDirectives(input string) []string { calls := scanDollarCalls(input, caseFormatDirectiveName) + return parseCaseFormatDirectiveCalls(calls) +} + +func parseCaseFormatDirectiveCalls(calls []directiveCall) []string { result := make([]string, 0, len(calls)) for _, call := range calls { if len(call.args) != 1 { @@ -433,3 +785,31 @@ func parseCaseFormatDirectives(input string) []string { } return result } + +func parseConstDirectives(input string) [][2]string { + calls := scanDollarCalls(input, constDirectiveName) + return parseConstDirectiveCalls(calls) +} + +func parseConstDirectiveCalls(calls []directiveCall) [][2]string { + var result [][2]string + for _, call := range calls { + if len(call.args) != 2 { + continue + } + name, ok := parseQuotedLiteral(call.args[0]) + if !ok { + continue + } + name = strings.TrimSpace(name) + if name == "" { + continue + } + value, ok := parseQuotedLiteral(call.args[1]) + if !ok { + continue + } + result = append(result, [2]string{name, strings.TrimSpace(value)}) + } + return result +} diff --git a/repository/shape/dql/sanitize/context_test.go b/repository/shape/dql/sanitize/context_test.go new file mode 100644 index 000000000..b19424763 --- /dev/null +++ b/repository/shape/dql/sanitize/context_test.go @@ -0,0 +1,118 @@ +package sanitize + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/velty" +) + +type criteriaContextMock struct{} + +type criteriaContextCollector struct { + Args []interface{} +} + +func (c *criteriaContextCollector) AppendBinding(value interface{}) string { + c.Args = append(c.Args, value) + return "?" +} + +type predicateContextMock struct{} + +func (p predicateContextMock) Builder() *predicateBuilderContextMock { + return &predicateBuilderContextMock{} +} + +func (p predicateContextMock) FilterGroup(group int, op string) string { + return fmt.Sprintf("P%d:%s", group, op) +} + +type predicateBuilderContextMock struct { + value string +} + +func (b *predicateBuilderContextMock) CombineOr(group string) *predicateBuilderContextMock { + b.value = group + return b +} + +func (b *predicateBuilderContextMock) Build(kind string) string { + switch kind { + case "AND": + return " AND (" + b.value + ") " + case "WHERE": + return " WHERE (" + b.value + ") " + default: + return "" + } +} + +type sqlContextMock struct{} + +func (s sqlContextMock) Eq(column string, value interface{}) string { + return fmt.Sprintf("%s = %v", column, value) +} + +type unsafeContextMock struct { + VendorID int +} + +func TestRenderVelty_WithShapeContext_DataDriven(t *testing.T) { + testCases := []struct { + name string + template string + expect string + args []interface{} + }{ + { + name: "criteria append binding", + template: "SELECT * FROM VENDOR t WHERE t.ID = $criteria.AppendBinding($Unsafe.VendorID)", + expect: "t.ID = ?", + args: []interface{}{101}, + }, + { + name: "predicate builder chain", + template: "SELECT * FROM PRODUCT t WHERE 1=1 ${predicate.Builder().CombineOr($predicate.FilterGroup(0, \"AND\")).Build(\"AND\")}", + expect: " AND (P0:AND) ", + }, + { + name: "sql helper", + template: "SELECT * FROM VENDOR t WHERE $sql.Eq(\"ID\", $Unsafe.VendorID)", + expect: "ID = 101", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + actual, args := renderVeltyWithShapeContext(t, testCase.template) + assert.Contains(t, actual, testCase.expect) + if len(testCase.args) > 0 { + assert.Equal(t, testCase.args, args) + } + }) + } +} + +func renderVeltyWithShapeContext(t *testing.T, template string) (string, []interface{}) { + t.Helper() + planner := velty.New() + require.NoError(t, planner.DefineVariable("criteria", &criteriaContextCollector{})) + require.NoError(t, planner.DefineVariable("predicate", predicateContextMock{})) + require.NoError(t, planner.DefineVariable("sql", sqlContextMock{})) + require.NoError(t, planner.DefineVariable("Unsafe", unsafeContextMock{})) + + exec, newState, err := planner.Compile([]byte(template)) + require.NoError(t, err) + + state := newState() + criteria := &criteriaContextCollector{} + require.NoError(t, state.SetValue("criteria", criteria)) + require.NoError(t, state.SetValue("predicate", predicateContextMock{})) + require.NoError(t, state.SetValue("sql", sqlContextMock{})) + require.NoError(t, state.SetValue("Unsafe", unsafeContextMock{VendorID: 101})) + require.NoError(t, exec.Exec(state)) + return state.Buffer.String(), criteria.Args +} diff --git a/repository/shape/dql/sanitize/policy.go b/repository/shape/dql/sanitize/policy.go index 4b6d6b5a4..a33b100c0 100644 --- a/repository/shape/dql/sanitize/policy.go +++ b/repository/shape/dql/sanitize/policy.go @@ -1,32 +1,91 @@ package sanitize -import "strings" +import ( + "strings" + + "github.com/viant/datly/view/keywords" + "github.com/viant/velty" +) type rewritePolicy struct { declared map[string]bool + foreach map[string]bool consts map[string]bool } -func newRewritePolicy(declared, consts map[string]bool) *rewritePolicy { +func newRewritePolicy(declared, foreach, consts map[string]bool) *rewritePolicy { return &rewritePolicy{ declared: declared, + foreach: foreach, consts: consts, } } -func (p *rewritePolicy) rewrite(raw string) string { +func (p *rewritePolicy) rewrite(raw string, kind velty.ExprContextKind) string { holder := holderName(raw) if holder == "" { return raw } + if keywords.Has(holder) { + return raw + } if strings.HasPrefix(raw, "$Unsafe.") || strings.HasPrefix(raw, "${Unsafe.") || strings.HasPrefix(raw, "$Has.") || strings.HasPrefix(raw, "${Has.") { return raw } if p.consts != nil && p.consts[holder] { return addUnsafePrefix(raw) } + if isControlOrFuncContext(kind) { + if p.declared != nil && p.declared[holder] { + return raw + } + if hasExplicitPrefix(raw) { + return raw + } + return addUnsafePrefix(raw) + } if p.declared != nil && p.declared[holder] { + if hasExplicitPrefix(raw) { + if p.foreach != nil && p.foreach[holder] { + return asPlaceholder(raw) + } + return asPlaceholder(addUnsafePrefix(raw)) + } return asPlaceholder(raw) } + if hasExplicitPrefix(raw) { + if p.foreach != nil && p.foreach[holder] { + return asPlaceholder(raw) + } + return asPlaceholder(addUnsafePrefix(raw)) + } return asPlaceholder(addUnsafePrefix(raw)) } + +func isControlOrFuncContext(kind velty.ExprContextKind) bool { + switch kind { + case velty.CtxFuncArg, + velty.CtxForEachCond, + velty.CtxIfCond, velty.CtxElseIfCond, + velty.CtxSetRHS, + velty.CtxForLoopInit, velty.CtxForLoopCond, velty.CtxForLoopPost: + return true + default: + return false + } +} + +func hasExplicitPrefix(raw string) bool { + name := strings.TrimSpace(raw) + if strings.HasPrefix(name, "${") && strings.HasSuffix(name, "}") { + name = "$" + name[2:len(name)-1] + } + if !strings.HasPrefix(name, "$") { + return false + } + name = strings.TrimPrefix(name, "$") + if idx := strings.Index(name, "("); idx != -1 { + return false + } + return strings.Index(name, ".") != -1 +} diff --git a/repository/shape/dql/sanitize/policy_test.go b/repository/shape/dql/sanitize/policy_test.go index c3ceb5f23..518a7575c 100644 --- a/repository/shape/dql/sanitize/policy_test.go +++ b/repository/shape/dql/sanitize/policy_test.go @@ -1,13 +1,19 @@ package sanitize -import "testing" +import ( + "testing" + + "github.com/viant/velty" +) func TestRewritePolicy_Rewrite(t *testing.T) { testCases := []struct { name string raw string declared map[string]bool + foreach map[string]bool consts map[string]bool + kind velty.ExprContextKind expect string }{ { @@ -26,6 +32,32 @@ func TestRewritePolicy_Rewrite(t *testing.T) { declared: map[string]bool{"x": true}, expect: "$criteria.AppendBinding($x)", }, + { + name: "declared foreach variable in body uses placeholder", + raw: "$rec.ID", + declared: map[string]bool{"rec": true}, + foreach: map[string]bool{"rec": true}, + kind: velty.CtxForEachBody, + expect: "$criteria.AppendBinding($rec.ID)", + }, + { + name: "declared dotted parameter uses unsafe placeholder in append context", + raw: "$Jwt.UserID", + declared: map[string]bool{"Jwt": true}, + expect: "$criteria.AppendBinding($Unsafe.Jwt.UserID)", + }, + { + name: "function arg gets unsafe prefix", + raw: "$VendorID", + kind: velty.CtxFuncArg, + expect: "$Unsafe.VendorID", + }, + { + name: "prefixed function arg is preserved", + raw: "$sql.Eq", + kind: velty.CtxFuncArg, + expect: "$sql.Eq", + }, { name: "const selector keeps raw unsafe path", raw: "$ConstID", @@ -40,8 +72,8 @@ func TestRewritePolicy_Rewrite(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - policy := newRewritePolicy(testCase.declared, testCase.consts) - if actual := policy.rewrite(testCase.raw); actual != testCase.expect { + policy := newRewritePolicy(testCase.declared, testCase.foreach, testCase.consts) + if actual := policy.rewrite(testCase.raw, testCase.kind); actual != testCase.expect { t.Fatalf("unexpected rewrite: %s", actual) } }) diff --git a/repository/shape/dql/sanitize/sanitizer.go b/repository/shape/dql/sanitize/sanitizer.go index ec414620d..83e94968e 100644 --- a/repository/shape/dql/sanitize/sanitizer.go +++ b/repository/shape/dql/sanitize/sanitizer.go @@ -11,6 +11,7 @@ import ( type Options struct { Declared map[string]bool + Foreach map[string]bool Consts map[string]bool } @@ -29,6 +30,21 @@ func Declared(input string) map[string]bool { ret[name] = true } } + for _, name := range scanForeachDeclaredHolders(input) { + if name != "" { + ret[name] = true + } + } + return ret +} + +func ForeachDeclared(input string) map[string]bool { + ret := map[string]bool{} + for _, name := range scanForeachDeclaredHolders(input) { + if name != "" { + ret[name] = true + } + } return ret } @@ -129,6 +145,56 @@ func isSanitizeIdentifierPart(ch byte) bool { return isSanitizeIdentifierStart(ch) || (ch >= '0' && ch <= '9') } +func scanForeachDeclaredHolders(input string) []string { + result := make([]string, 0) + lower := strings.ToLower(input) + for i := 0; i < len(input); i++ { + if input[i] != '#' { + continue + } + if !strings.HasPrefix(lower[i:], "#foreach") { + continue + } + j := i + len("#foreach") + for j < len(input) && (input[j] == ' ' || input[j] == '\t' || input[j] == '\r' || input[j] == '\n') { + j++ + } + if j >= len(input) || input[j] != '(' { + continue + } + body, end, ok := readSetDirectiveBody(input, j) + if !ok { + continue + } + if name, ok := parseForeachHolder(body); ok { + result = append(result, name) + } + i = end + } + return result +} + +func parseForeachHolder(body string) (string, bool) { + text := strings.TrimSpace(body) + if text == "" || text[0] != '$' || len(text) < 2 { + return "", false + } + text = text[1:] + end := 0 + for end < len(text) && isSanitizeIdentifierPart(text[end]) { + end++ + } + if end == 0 { + return "", false + } + name := text[:end] + rest := strings.TrimSpace(text[end:]) + if !strings.HasPrefix(strings.ToLower(rest), "in ") { + return "", false + } + return name, true +} + func SQL(input string, opts Options) string { return Rewrite(input, opts).SQL } @@ -140,8 +206,9 @@ func Rewrite(input string, opts Options) RewriteResult { adjuster := &bindingAdjuster{ source: []byte(input), declared: opts.Declared, + foreach: opts.Foreach, consts: opts.Consts, - policy: newRewritePolicy(opts.Declared, opts.Consts), + policy: newRewritePolicy(opts.Declared, opts.Foreach, opts.Consts), } out, err := velty.TransformTemplate([]byte(input), adjuster) if err != nil { @@ -158,12 +225,18 @@ func Rewrite(input string, opts Options) RewriteResult { type bindingAdjuster struct { source []byte declared map[string]bool + foreach map[string]bool consts map[string]bool policy *rewritePolicy patches []velty.Patch } func (b *bindingAdjuster) Adjust(node ast.Node, ctx *velty.ParserContext) (velty.Action, error) { + if call, ok := node.(*aexpr.Call); ok { + b.rewriteCallNodeArgs(call, ctx) + return velty.Keep(), nil + } + sel, ok := node.(*aexpr.Select) if !ok { return velty.Keep(), nil @@ -179,7 +252,7 @@ func (b *bindingAdjuster) Adjust(node ast.Node, ctx *velty.ParserContext) (velty return velty.Keep(), nil } raw := string(b.source[span.Start : span.End+1]) - replacement := b.rewrite(raw) + replacement := b.rewrite(raw, ctx.CurrentExprContext().Kind) if replacement == raw { return velty.Keep(), nil } @@ -190,6 +263,50 @@ func (b *bindingAdjuster) Adjust(node ast.Node, ctx *velty.ParserContext) (velty return velty.PatchSpan(span, []byte(replacement)), nil } +func (b *bindingAdjuster) rewriteCallNodeArgs(call *aexpr.Call, ctx *velty.ParserContext) { + selectors := make([]*aexpr.Select, 0, 4) + for _, arg := range call.Args { + b.collectSelectors(arg, &selectors) + } + for _, sel := range selectors { + span, ok := ctx.GetSpan(sel) + if !ok { + continue + } + if b.inSetDirective(span.Start) { + continue + } + raw := string(b.source[span.Start : span.End+1]) + replacement := b.rewrite(raw, velty.CtxFuncArg) + if replacement == raw { + continue + } + b.patches = append(b.patches, velty.Patch{ + Span: span, + Replacement: []byte(replacement), + }) + } +} + +func (b *bindingAdjuster) collectSelectors(expr ast.Expression, selectors *[]*aexpr.Select) { + switch actual := expr.(type) { + case *aexpr.Select: + *selectors = append(*selectors, actual) + case *aexpr.Call: + b.collectSelectors(actual.X, selectors) + for _, arg := range actual.Args { + b.collectSelectors(arg, selectors) + } + case *aexpr.Binary: + b.collectSelectors(actual.X, selectors) + b.collectSelectors(actual.Y, selectors) + case *aexpr.Unary: + b.collectSelectors(actual.X, selectors) + case *aexpr.Parentheses: + b.collectSelectors(actual.P, selectors) + } +} + func (b *bindingAdjuster) inSetDirective(pos int) bool { if pos <= 0 || pos > len(b.source) { return false @@ -206,11 +323,166 @@ func (b *bindingAdjuster) inSetDirective(pos int) bool { return strings.Count(segment, "(") > strings.Count(segment, ")") } -func (b *bindingAdjuster) rewrite(raw string) string { +func (b *bindingAdjuster) rewrite(raw string, kind velty.ExprContextKind) string { if b.policy == nil { - b.policy = newRewritePolicy(b.declared, b.consts) + b.policy = newRewritePolicy(b.declared, b.foreach, b.consts) + } + if rewritten, ok := b.rewriteCallArgsInSelector(raw); ok { + return rewritten + } + return b.policy.rewrite(raw, kind) +} + +func (b *bindingAdjuster) rewriteCallArgsInSelector(raw string) (string, bool) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" || !strings.HasPrefix(trimmed, "$") || !strings.Contains(trimmed, "(") { + return "", false + } + + hasBraces := strings.HasPrefix(trimmed, "${") && strings.HasSuffix(trimmed, "}") + expr := trimmed + if hasBraces { + expr = "$" + trimmed[2:len(trimmed)-1] + } + + open := strings.Index(expr, "(") + if open <= 0 || !strings.HasPrefix(expr, "$") { + return "", false + } + closeIdx, ok := matchingParen(expr, open) + if !ok || closeIdx != len(expr)-1 { + return "", false + } + + args := expr[open+1 : closeIdx] + rewrittenArgs, changed := b.rewriteCallArgs(args) + if !changed { + return "", false + } + rewritten := expr[:open+1] + rewrittenArgs + expr[closeIdx:] + if hasBraces { + rewritten = "${" + rewritten[1:] + "}" + } + return rewritten, true +} + +func (b *bindingAdjuster) rewriteCallArgs(args string) (string, bool) { + parts := splitArgs(args) + if len(parts) == 0 { + return args, false + } + changed := false + for i := range parts { + part := parts[i] + lead, core, tail := trimArgWhitespace(part) + if core == "" { + continue + } + rewrittenCore := core + if strings.HasPrefix(core, "$") { + if nested, ok := b.rewriteCallArgsInSelector(core); ok { + rewrittenCore = nested + } else { + rewrittenCore = b.policy.rewrite(core, velty.CtxFuncArg) + } + } + if rewrittenCore != core { + changed = true + parts[i] = lead + rewrittenCore + tail + } + } + if !changed { + return args, false + } + return strings.Join(parts, ","), true +} + +func splitArgs(input string) []string { + if input == "" { + return nil + } + result := make([]string, 0, 4) + start := 0 + depth := 0 + quote := byte(0) + for i := 0; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + if depth > 0 { + depth-- + } + continue + } + if ch == ',' && depth == 0 { + result = append(result, input[start:i]) + start = i + 1 + } + } + result = append(result, input[start:]) + return result +} + +func trimArgWhitespace(input string) (string, string, string) { + start := 0 + for start < len(input) && (input[start] == ' ' || input[start] == '\t' || input[start] == '\n' || input[start] == '\r') { + start++ + } + end := len(input) + for end > start && (input[end-1] == ' ' || input[end-1] == '\t' || input[end-1] == '\n' || input[end-1] == '\r') { + end-- + } + return input[:start], input[start:end], input[end:] +} + +func matchingParen(input string, open int) (int, bool) { + depth := 0 + quote := byte(0) + for i := open; i < len(input); i++ { + ch := input[i] + if quote != 0 { + if ch == '\\' && i+1 < len(input) { + i++ + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '\'' || ch == '"' { + quote = ch + continue + } + if ch == '(' { + depth++ + continue + } + if ch == ')' { + depth-- + if depth == 0 { + return i, true + } + } } - return b.policy.rewrite(raw) + return -1, false } func holderName(raw string) string { diff --git a/repository/shape/dql/sanitize/sanitizer_test.go b/repository/shape/dql/sanitize/sanitizer_test.go index 326bd3b1d..308bb5068 100644 --- a/repository/shape/dql/sanitize/sanitizer_test.go +++ b/repository/shape/dql/sanitize/sanitizer_test.go @@ -39,6 +39,20 @@ func TestSQL_ParityWithLegacySanitizer(t *testing.T) { &inference.Parameter{Parameter: vstate.Parameter{Name: "ConstId", In: vstate.NewConstLocation("ConstId")}}, }, }, + { + name: "exec foreach and logger hooks", + sql: "#foreach($rec in $Unsafe.Records)\n" + + "#if($rec.IS_AUTH == 0)\n" + + " $logger.Fatal(\"Unauthorized access to product: %v\", $rec.ID)\n" + + "#end\n" + + "UPDATE PRODUCT SET STATUS = $Status WHERE ID = $rec.ID\n" + + "#end", + }, + { + name: "predicate and sql hooks", + sql: "SELECT * FROM PRODUCT t WHERE 1=1 ${predicate.Builder().CombineOr($predicate.FilterGroup(0, \"AND\")).Build(\"AND\")} " + + "AND $sql.Eq(\"ID\", $VendorID)", + }, } for _, testCase := range testCases { @@ -50,6 +64,7 @@ func TestSQL_ParityWithLegacySanitizer(t *testing.T) { actual := SQL(testCase.sql, Options{ Declared: tpl.Declared, + Foreach: ForeachDeclared(testCase.sql), Consts: constNames(state), }) assert.Equal(t, expected, actual) @@ -82,6 +97,20 @@ func TestSQL_ParityWithLegacySanitizer_RuntimeExpansion(t *testing.T) { &inference.Parameter{Parameter: vstate.Parameter{Name: "ConstId", In: vstate.NewConstLocation("ConstId")}}, }, }, + { + name: "exec foreach and logger hooks", + sql: "#foreach($rec in $Unsafe.Records)\n" + + "#if($rec.IS_AUTH == 0)\n" + + " $logger.Fatal(\"Unauthorized access to product: %v\", $rec.ID)\n" + + "#end\n" + + "UPDATE PRODUCT SET STATUS = $Status WHERE ID = $rec.ID\n" + + "#end", + }, + { + name: "predicate and sql hooks", + sql: "SELECT * FROM PRODUCT t WHERE 1=1 ${predicate.Builder().CombineOr($predicate.FilterGroup(0, \"AND\")).Build(\"AND\")} " + + "AND $sql.Eq(\"ID\", $VendorID)", + }, } for _, testCase := range testCases { @@ -93,6 +122,7 @@ func TestSQL_ParityWithLegacySanitizer_RuntimeExpansion(t *testing.T) { shapeSQL := SQL(testCase.sql, Options{ Declared: tpl.Declared, + Foreach: ForeachDeclared(testCase.sql), Consts: constNames(state), }) require.Equal(t, legacySQL, shapeSQL) @@ -156,6 +186,11 @@ func TestDeclared_ParameterDeclarationStyle(t *testing.T) { assert.True(t, declared["Jwt"]) } +func TestDeclared_ForeachVariable(t *testing.T) { + declared := Declared("#foreach($rec in $Unsafe.Records)\nUPDATE t SET v = $rec.ID\n#end") + assert.True(t, declared["rec"]) +} + func TestDeclaredListener_OnEventBranches(t *testing.T) { declared := map[string]bool{} l := &declaredListener{declared: declared} @@ -218,9 +253,57 @@ func (c criteriaMock) AppendBinding(value interface{}) string { } type unsafeMock struct { - Id int - Name string - ConstId int + Id int + Name string + ConstId int + VendorID int + Status int + Records []recordMock +} + +type recordMock struct { + ID int + IS_AUTH int +} + +type sqlMock struct{} + +func (s sqlMock) Eq(column string, value interface{}) string { + return fmt.Sprintf("%s = %v", column, value) +} + +type predicateMock struct{} + +func (p predicateMock) Builder() *predicateBuilderMock { + return &predicateBuilderMock{} +} + +func (p predicateMock) FilterGroup(group int, op string) string { + return fmt.Sprintf("P%d:%s", group, op) +} + +type predicateBuilderMock struct { + value string +} + +func (b *predicateBuilderMock) CombineOr(group string) *predicateBuilderMock { + b.value = group + return b +} + +func (b *predicateBuilderMock) Build(kind string) string { + switch kind { + case "AND": + return " AND (" + b.value + ") " + default: + return "" + } +} + +type loggerMock struct{} + +func (l loggerMock) Fatal(_ string, _ ...interface{}) string { + return "" } func renderVeltySQL(t *testing.T, template string) string { @@ -228,18 +311,35 @@ func renderVeltySQL(t *testing.T, template string) string { planner := velty.New() require.NoError(t, planner.DefineVariable("criteria", criteriaMock{})) require.NoError(t, planner.DefineVariable("Unsafe", unsafeMock{})) + require.NoError(t, planner.DefineVariable("sql", sqlMock{})) + require.NoError(t, planner.DefineVariable("predicate", predicateMock{})) + require.NoError(t, planner.DefineVariable("logger", loggerMock{})) require.NoError(t, planner.DefineVariable("Id", 0)) require.NoError(t, planner.DefineVariable("Name", "")) require.NoError(t, planner.DefineVariable("ConstId", 0)) + require.NoError(t, planner.DefineVariable("VendorID", 0)) + require.NoError(t, planner.DefineVariable("Status", 0)) exec, newState, err := planner.Compile([]byte(template)) require.NoError(t, err) state := newState() require.NoError(t, state.SetValue("criteria", criteriaMock{})) - require.NoError(t, state.SetValue("Unsafe", unsafeMock{Id: 10, Name: "ann", ConstId: 77})) + require.NoError(t, state.SetValue("Unsafe", unsafeMock{ + Id: 10, + Name: "ann", + ConstId: 77, + VendorID: 101, + Status: 1, + Records: []recordMock{{ID: 10, IS_AUTH: 1}}, + })) + require.NoError(t, state.SetValue("sql", sqlMock{})) + require.NoError(t, state.SetValue("predicate", predicateMock{})) + require.NoError(t, state.SetValue("logger", loggerMock{})) require.NoError(t, state.SetValue("Id", 10)) require.NoError(t, state.SetValue("Name", "ann")) require.NoError(t, state.SetValue("ConstId", 77)) + require.NoError(t, state.SetValue("VendorID", 101)) + require.NoError(t, state.SetValue("Status", 1)) require.NoError(t, exec.Exec(state)) return state.Buffer.String() } diff --git a/repository/shape/dql/scan/scanner.go b/repository/shape/dql/scan/scanner.go index b7ecb2d69..8dac0e291 100644 --- a/repository/shape/dql/scan/scanner.go +++ b/repository/shape/dql/scan/scanner.go @@ -1,24 +1,16 @@ package scan import ( - "context" "fmt" - "path/filepath" "reflect" "strings" "time" _ "github.com/go-sql-driver/mysql" - "github.com/viant/afs" - "github.com/viant/afs/file" - "github.com/viant/afs/url" - "github.com/viant/datly/cmd/options" - "github.com/viant/datly/internal/translator" "github.com/viant/datly/repository/shape/dql/decl" "github.com/viant/datly/repository/shape/dql/ir" "github.com/viant/datly/repository/shape/dql/parse" dqlplan "github.com/viant/datly/repository/shape/dql/plan" - "github.com/viant/datly/repository/shape/dql/sanitize" dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/typectx" "github.com/viant/datly/repository/shape/typectx/source" @@ -49,87 +41,16 @@ type Result struct { } // Scanner translates DQL to Datly route YAML in-memory. -type Scanner struct { - fs afs.Service -} +type Scanner struct{} func New() *Scanner { - return &Scanner{fs: afs.New()} -} - -func (s *Scanner) Scan(ctx context.Context, req *Request) (result *Result, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("dql scan panic: %v", r) - result = nil - } - }() - if req == nil || req.DQLURL == "" { - return nil, fmt.Errorf("dql scan: DQLURL was empty") - } - sourceURL := req.DQLURL - project := inferProject(req.DQLURL) - translate := &options.Translate{} - translate.Rule.Project = project - translate.Rule.Source = []string{sourceURL} - translate.Rule.ModulePrefix = req.ModulePrefix - translate.Repository.RepositoryURL = req.Repository - translate.Repository.APIPrefix = req.APIPrefix - if len(req.Connectors) > 0 { - translate.Repository.Connectors = append(translate.Repository.Connectors, req.Connectors...) - } - if req.ConfigURL != "" { - translate.Repository.Configs.Append(req.ConfigURL) - } - var initErr error - if initErr = translate.Init(ctx); initErr != nil { - return nil, initErr - } - if req.ConfigURL == "" { - // Force in-memory translator config to avoid stale absolute paths from discovered config.json. - translate.Repository.Configs = nil - } - if translate.Rule.ModulePrefix == "" { - translate.Rule.ModulePrefix = "platform" - } - - svc := translator.New(translator.NewConfig(&translate.Repository), s.fs) - if initErr := svc.Init(ctx); initErr != nil { - return nil, initErr - } - if initErr := svc.InitSignature(ctx, &translate.Rule); initErr != nil { - return nil, initErr - } - dsql, loadErr := translate.Rule.LoadSource(ctx, s.fs, translate.Rule.SourceURL()) - if loadErr != nil { - return nil, loadErr - } - translate.Rule.NormalizeComponent(&dsql) - dsql = sanitize.SQL(dsql, sanitize.Options{Declared: sanitize.Declared(dsql)}) - top := &options.Options{Translate: translate} - if initErr = svc.Translate(ctx, &translate.Rule, dsql, top); initErr != nil { - return nil, initErr - } - ruleName := svc.Repository.RuleName(&translate.Rule) - targetSuffix := "/" + ruleName + ".yaml" - for _, item := range svc.Repository.Files { - if !strings.HasSuffix(item.URL, targetSuffix) { - continue - } - if strings.Contains(item.URL, "/.meta/") { - continue - } - return s.result(ruleName, []byte(item.Content), dsql, req) - } - for _, item := range svc.Repository.Files { - if strings.HasSuffix(item.URL, targetSuffix) { - return s.result(ruleName, []byte(item.Content), dsql, req) - } - } - return nil, fmt.Errorf("dql scan: generated YAML not found for %s", ruleName) + return &Scanner{} } -func (s *Scanner) result(ruleName string, routeYAML []byte, dql string, req *Request) (*Result, error) { +// Result builds a scan Result from route YAML bytes. Exported so that bridge +// packages (e.g. testutil/shapeparity) can call it after running the legacy +// translator pipeline externally. +func (s *Scanner) Result(ruleName string, routeYAML []byte, dql string, req *Request) (*Result, error) { if err := dqlplan.ValidateRelations(routeYAML); err != nil { return nil, fmt.Errorf("dql scan relation validation failed (%s): %w", ruleName, err) } @@ -421,11 +342,3 @@ func validateResolutionPolicy(resolution typectx.Resolution, policy provenancePo } return "" } - -func inferProject(dqlURL string) string { - base, _ := url.Split(dqlURL, file.Scheme) - if idx := strings.Index(base, "/dql/"); idx != -1 { - return filepath.Clean(base[:idx]) - } - return filepath.Clean(base) -} diff --git a/repository/shape/dql/scan/scanner_test.go b/repository/shape/dql/scan/scanner_test.go index 50e90999a..9374e9a7c 100644 --- a/repository/shape/dql/scan/scanner_test.go +++ b/repository/shape/dql/scan/scanner_test.go @@ -31,7 +31,7 @@ Resource: Template: Source: SELECT c.ID FROM T2 c `) - _, err := s.result("x", invalidYAML, "", nil) + _, err := s.Result("x", invalidYAML, "", nil) require.Error(t, err) require.Contains(t, err.Error(), "dql scan relation validation failed") require.Contains(t, err.Error(), "column=\"MISSING_COL\"") @@ -54,7 +54,7 @@ Resource: Template: Source: SELECT r.ID FROM ROOT r `) - result, err := s.result("sample", validYAML, "", nil) + result, err := s.Result("sample", validYAML, "", nil) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Shape) @@ -83,7 +83,7 @@ Resource: #package('mdp/performance') #import('perf', 'github.com/acme/mdp/performance') SELECT r.ID FROM ROOT r` - result, err := s.result("sample", validYAML, dql, nil) + result, err := s.Result("sample", validYAML, dql, nil) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Shape) @@ -117,7 +117,7 @@ Resource: dql := ` #package('github.com/acme/mdp/performance') SELECT cast(r.ID as 'Order') FROM ROOT r` - result, err := s.result("sample", validYAML, dql, nil) + result, err := s.Result("sample", validYAML, dql, nil) require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.Shape) @@ -155,7 +155,7 @@ Resource: #package('github.com/acme/mdp/performance') SELECT cast(r.ID as 'Order') FROM ROOT r` strict := true - _, err := s.result("sample", validYAML, dql, &Request{ + _, err := s.Result("sample", validYAML, dql, &Request{ Repository: filepath.Clean(t.TempDir()), StrictProvenance: &strict, }) diff --git a/repository/shape/dql/shape/model.go b/repository/shape/dql/shape/model.go index e975cfca3..5b35af53a 100644 --- a/repository/shape/dql/shape/model.go +++ b/repository/shape/dql/shape/model.go @@ -42,9 +42,18 @@ type Diagnostic struct { type Directives struct { Meta string DefaultConnector string + TemplateType string + Dest string + InputDest string + OutputDest string + RouterDest string + InputType string + OutputType string Cache *CacheDirective MCP *MCPDirective Route *RouteDirective + Report *ReportDirective + Const map[string]string JSONMarshalType string JSONUnmarshalType string XMLUnmarshalType string @@ -54,8 +63,12 @@ type Directives struct { } type CacheDirective struct { - Enabled bool - TTL string + Enabled bool + TTL string + Name string + Provider string + Location string + TimeToLiveMs int } type MCPDirective struct { @@ -69,6 +82,17 @@ type RouteDirective struct { Methods []string } +type ReportDirective struct { + Enabled bool + Input string + Dimensions string + Measures string + Filters string + OrderBy string + Limit string + Offset string +} + type Route struct { Name string URI string diff --git a/repository/shape/dql_engine_test.go b/repository/shape/dql_engine_test.go index a7a40fff5..a19db4b8d 100644 --- a/repository/shape/dql_engine_test.go +++ b/repository/shape/dql_engine_test.go @@ -2,13 +2,24 @@ package shape_test import ( "context" + "fmt" + "os" + "path/filepath" + "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + marshalconfig "github.com/viant/datly/gateway/router/marshal/config" + marshaljson "github.com/viant/datly/gateway/router/marshal/json" shape "github.com/viant/datly/repository/shape" shapeCompile "github.com/viant/datly/repository/shape/compile" shapeLoad "github.com/viant/datly/repository/shape/load" + shapePlan "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view" + "github.com/viant/datly/view/extension" + "github.com/viant/tagly/format/text" + "github.com/viant/xreflect" ) func TestEngine_LoadDQLViews(t *testing.T) { @@ -24,6 +35,20 @@ func TestEngine_LoadDQLViews(t *testing.T) { assert.Equal(t, "t", artifacts.Views[0].Name) } +func TestEngine_LoadDQLResource(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/reports/orders"), + ) + artifacts, err := engine.LoadDQLResource(context.Background(), "SELECT id FROM ORDERS t") + require.NoError(t, err) + require.NotNil(t, artifacts) + require.NotNil(t, artifacts.Resource) + require.Len(t, artifacts.Resource.Views, 1) + assert.Equal(t, "t", artifacts.Resource.Views[0].Name) +} + func TestEngine_LoadDQLComponent(t *testing.T) { engine := shape.New( shape.WithCompiler(shapeCompile.New()), @@ -64,3 +89,343 @@ SELECT id FROM ORDERS t` require.NotEmpty(t, component.Predicates["o"]) assert.Equal(t, "ByID", component.Predicates["o"][0].Name) } + +func TestEngine_LoadDQLComponent_PreservesExplicitOutputViewOneCardinality(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/shape/dev/auth/user-acl"), + ) + dql := ` +#define($_ = $Data(output/view).Cardinality('One').Embed()) +SELECT 1 AS UserID` + artifact, err := engine.LoadDQLComponent(context.Background(), dql) + require.NoError(t, err) + require.NotNil(t, artifact) + + component, ok := shapeLoad.ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.Output, 1) + require.NotNil(t, component.Output[0].Schema) + assert.Equal(t, "One", string(component.Output[0].Schema.Cardinality)) +} + +type metaFormatOutput struct { + Meta *metaFormatMeta `json:"meta,omitempty"` + Data []metaFormatData `json:"data,omitempty"` + Status string `json:"status,omitempty"` +} + +type metaFormatMeta struct { + PageCnt *int `json:"pageCnt,omitempty"` + Cnt int `json:"cnt,omitempty"` +} + +type metaFormatData struct { + Id int `json:"id,omitempty"` + Name *string `json:"name,omitempty"` + AccountId *int `json:"accountId,omitempty"` + Products []*metaFormatProduct `json:"products,omitempty"` + ProductsMeta *metaFormatProductsMeta `json:"productsMeta,omitempty"` +} + +type metaFormatProduct struct { + Id int `json:"id,omitempty"` + Name *string `json:"name,omitempty"` + VendorId *int `json:"vendorId,omitempty"` +} + +type metaFormatProductsMeta struct { + VendorId *int `json:"vendorId,omitempty"` + PageCnt *int `json:"pageCnt,omitempty"` + TotalProducts int `json:"totalProducts,omitempty"` +} + +func TestMetaFormatLiveLikeOutput_Marshal(t *testing.T) { + name := "Acme" + id := 1 + pageCnt := 2 + output := &metaFormatOutput{ + Meta: &metaFormatMeta{PageCnt: &pageCnt, Cnt: 3}, + Data: []metaFormatData{ + { + Id: 1, + Name: &name, + AccountId: &id, + Products: []*metaFormatProduct{ + {Id: 10, Name: &name, VendorId: &id}, + }, + ProductsMeta: &metaFormatProductsMeta{VendorId: &id, PageCnt: &pageCnt, TotalProducts: 1}, + }, + }, + Status: "ok", + } + marshaller := marshaljson.New(&marshalconfig.IOConfig{CaseFormat: text.CaseFormatLowerCamel}) + _, err := marshaller.Marshal(output) + require.NoError(t, err) +} + +func TestDQLCompileLoad_MetaFormatPreservesSummariesWithoutLinkedTypes(t *testing.T) { + dqlPath := filepath.Join("..", "..", "e2e", "v1", "dql", "dev", "vendorsrv", "meta_format.dql") + dqlBytes, err := os.ReadFile(dqlPath) + require.NoError(t, err) + + source := &shape.Source{ + Name: "meta_format", + Path: dqlPath, + DQL: string(dqlBytes), + } + planResult, err := shapeCompile.New().Compile( + context.Background(), + source, + shape.WithLinkedTypes(false), + shape.WithTypeContextPackageDir(filepath.Join("e2e", "v1", "shape", "dev", "vendorsvc", "multi_summary")), + shape.WithTypeContextPackageName("multi_summary"), + ) + require.NoError(t, err) + registry := source.EnsureTypeRegistry() + require.NotNil(t, registry) + if lookup := registry.Lookup("ProductsView"); lookup != nil { + fmt.Printf("registry ProductsView: %T %v\n", lookup.Type, lookup.Type) + } else { + fmt.Printf("registry ProductsView: \n") + } + if lookup := registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary.ProductsView"); lookup != nil { + fmt.Printf("registry fq ProductsView: %T %v\n", lookup.Type, lookup.Type) + } else { + fmt.Printf("registry fq ProductsView: \n") + } + planned, ok := shapePlan.ResultFrom(planResult) + require.True(t, ok) + foundPlannedProductsType := false + for _, item := range planned.Types { + if item != nil && item.Name == "ProductsView" { + foundPlannedProductsType = true + assert.NotEmpty(t, item.DataType) + } + } + var plannedProductsView *shapePlan.View + for _, item := range planned.Views { + if item != nil && item.Name == "products" { + plannedProductsView = item + break + } + } + require.NotNil(t, plannedProductsView) + require.NotNil(t, plannedProductsView.FieldType) + t.Logf("planned products fieldType=%v elementType=%v", plannedProductsView.FieldType, plannedProductsView.ElementType) + assert.False(t, foundPlannedProductsType) + registered := registry.Lookup("github.com/viant/datly/e2e/v1/shape/dev/vendorsvc/multi_summary.ProductsMetaView") + require.NotNil(t, registered) + require.NotNil(t, registered.Type) + registeredType := registered.Type + if registeredType.Kind() == reflect.Ptr { + registeredType = registeredType.Elem() + } + field, ok := registeredType.FieldByName("VendorId") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), field.Type) + + resourceArtifacts, err := shapeLoad.New().LoadResource(context.Background(), planResult, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + require.NotNil(t, resourceArtifacts) + require.NotNil(t, resourceArtifacts.Resource) + + index := resourceArtifacts.Resource.Views.Index() + root, err := index.Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, root) + t.Logf("root view: name=%s ref=%s schema=%v with=%d", root.Name, root.Ref, root.Schema != nil, len(root.With)) + for i, rel := range root.With { + if rel == nil || rel.Of == nil { + t.Logf("root relation[%d]: nil", i) + continue + } + relSchemaType := "" + if rel.Of.View.Schema != nil && rel.Of.View.Schema.Type() != nil { + relSchemaType = rel.Of.View.Schema.Type().String() + } + t.Logf("root relation[%d]: holder=%s name=%s ref=%s schema=%v schemaType=%s summary=%v", i, rel.Holder, rel.Of.View.Name, rel.Of.View.Ref, rel.Of.View.Schema != nil, relSchemaType, rel.Of.View.Template != nil && rel.Of.View.Template.Summary != nil) + } + require.NotNil(t, root.Template) + require.NotNil(t, root.Template.Summary) + require.NotNil(t, root.Template.Summary.Schema) + assert.Equal(t, "MetaView", root.Template.Summary.Schema.Name) + + child, err := index.Lookup("products") + require.NoError(t, err) + require.NotNil(t, child) + require.NotNil(t, child.Template) + require.NotNil(t, child.Template.Summary) + require.NotNil(t, child.Template.Summary.Schema) + assert.Equal(t, "ProductsMetaView", child.Template.Summary.Schema.Name) + childViewSummaryType := child.Template.Summary.Schema.Type() + require.NotNil(t, childViewSummaryType) + if childViewSummaryType.Kind() == reflect.Ptr { + childViewSummaryType = childViewSummaryType.Elem() + } + field, ok = childViewSummaryType.FieldByName("VendorId") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), field.Type) + require.NotEmpty(t, root.With) + require.NotNil(t, root.With[0].Of) + require.NotNil(t, root.With[0].Of.View.Template) + require.NotNil(t, root.With[0].Of.View.Template.Summary) + childSummaryType := root.With[0].Of.View.Template.Summary.Schema.Type() + require.NotNil(t, childSummaryType) + if childSummaryType.Kind() == reflect.Ptr { + childSummaryType = childSummaryType.Elem() + } + field, ok = childSummaryType.FieldByName("VendorId") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), field.Type) + + componentArtifact, err := shapeLoad.New().LoadComponent(context.Background(), planResult, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + component, ok := shapeLoad.ComponentFrom(componentArtifact) + require.True(t, ok) + foundProductsType := false + for _, item := range componentArtifact.Resource.Types { + if item != nil { + t.Logf("resource type: name=%s dataType=%s fields=%d package=%s module=%s", item.Name, item.DataType, len(item.Fields), item.Package, item.ModulePath) + } + if item == nil || item.Name != "ProductsView" { + continue + } + foundProductsType = true + require.NotEmpty(t, item.Fields) + break + } + require.True(t, foundProductsType) + typeRegistry, err := initTypeRegistryForResource(componentArtifact.Resource) + require.NoError(t, err) + + foundSummary := false + for _, param := range component.Output { + if param != nil && param.In != nil && param.In.Name == "summary" { + foundSummary = true + require.NotNil(t, param.Schema) + assert.Equal(t, "MetaView", param.Schema.Name) + } + } + assert.True(t, foundSummary) + + outputType, err := component.OutputReflectType("", typeRegistry.Lookup) + require.NoError(t, err) + require.NotNil(t, outputType) + + output := reflect.New(outputType).Elem() + dataField := output.FieldByName("Data") + require.True(t, dataField.IsValid()) + require.Equal(t, reflect.Slice, dataField.Kind()) + + rowType := dataField.Type().Elem() + fmt.Printf("output Data type: %T %v\n", dataField.Interface(), dataField.Type()) + rowValue := reflect.New(rowType) + if rowType.Kind() == reflect.Ptr { + rowValue = reflect.New(rowType.Elem()) + } + row := rowValue.Elem() + row.FieldByName("Id").SetInt(1) + + productsField := row.FieldByName("Products") + require.True(t, productsField.IsValid()) + productType := productsField.Type().Elem() + product := reflect.New(productType) + if productType.Kind() == reflect.Ptr { + product = reflect.New(productType.Elem()) + } + product.Elem().FieldByName("Id").SetInt(10) + if productType.Kind() == reflect.Ptr { + productsField.Set(reflect.Append(productsField, product)) + } else { + productsField.Set(reflect.Append(productsField, product.Elem())) + } + + data := reflect.MakeSlice(dataField.Type(), 0, 1) + if rowType.Kind() == reflect.Ptr { + data = reflect.Append(data, rowValue) + } else { + data = reflect.Append(data, row) + } + dataField.Set(data) + + marshaller := marshaljson.New(&marshalconfig.IOConfig{CaseFormat: text.CaseFormatLowerCamel}) + _, err = marshaller.Marshal(output.Addr().Interface()) + require.NoError(t, err) +} + +func TestDQLCompileLoad_DistrictPaginationInheritsNestedRelationTypeContextPackages(t *testing.T) { + dqlPath := filepath.Join("..", "..", "e2e", "v1", "dql", "dev", "district", "district_pagination.sql") + dqlPath, err := filepath.Abs(dqlPath) + require.NoError(t, err) + dqlBytes, err := os.ReadFile(dqlPath) + require.NoError(t, err) + + source := &shape.Source{ + Name: "district_pagination", + Path: dqlPath, + DQL: string(dqlBytes), + } + planResult, err := shapeCompile.New().Compile( + context.Background(), + source, + shape.WithLinkedTypes(false), + shape.WithTypeContextPackageDir(filepath.Join("e2e", "v1", "shape", "dev", "district", "pagination")), + shape.WithTypeContextPackageName("pagination"), + ) + require.NoError(t, err) + + componentArtifact, err := shapeLoad.New().LoadComponent(context.Background(), planResult, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + + component, ok := shapeLoad.ComponentFrom(componentArtifact) + require.True(t, ok) + require.NotNil(t, component) + + root, err := componentArtifact.Resource.Views.Index().Lookup(component.RootView) + require.NoError(t, err) + require.NotNil(t, root) + require.NotNil(t, root.Schema) + assert.Equal(t, "pagination", root.Schema.Package) + assert.Equal(t, "github.com/viant/datly/e2e/v1/shape/dev/district/pagination", root.Schema.PackagePath) + + require.Len(t, root.With, 1) + child := &root.With[0].Of.View + require.NotNil(t, child.Schema) + assert.Equal(t, "pagination", child.Schema.Package) + assert.Equal(t, "github.com/viant/datly/e2e/v1/shape/dev/district/pagination", child.Schema.PackagePath) + + _, err = initTypeRegistryForResource(componentArtifact.Resource) + require.NoError(t, err) +} + +func initTypeRegistryForResource(resource *view.Resource) (*xreflect.Types, error) { + registry := extension.NewRegistry() + imports := view.Imports{} + for _, definition := range resource.Types { + if definition != nil && definition.ModulePath != "" { + imports.Add(definition.ModulePath) + if definition.Package != "" { + imports.AddWithAlias(definition.Package, definition.ModulePath) + } + } + } + for _, definition := range resource.Types { + if definition == nil { + continue + } + if err := definition.Init(context.Background(), registry.Types.Lookup, imports); err != nil { + return nil, err + } + if err := registry.Types.Register(definition.Name, xreflect.WithReflectType(definition.Type())); err != nil { + return nil, err + } + if definition.Package != "" { + if err := registry.Types.Register(definition.Name, xreflect.WithPackage(definition.Package), xreflect.WithReflectType(definition.Type())); err != nil { + return nil, err + } + } + } + return registry.Types, nil +} diff --git a/repository/shape/gorouter/discover.go b/repository/shape/gorouter/discover.go new file mode 100644 index 000000000..ddab0b597 --- /dev/null +++ b/repository/shape/gorouter/discover.go @@ -0,0 +1,671 @@ +package gorouter + +import ( + "bufio" + "context" + "fmt" + "go/ast" + "go/token" + "io/fs" + "os" + "path" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/componenttag" + "github.com/viant/datly/view/extension" + tagtags "github.com/viant/tagly/tags" + "github.com/viant/x" + "github.com/viant/xreflect" + "golang.org/x/tools/go/packages" +) + +// Discover scans Go packages for router holders and returns one route source per component-tagged field. +func Discover(ctx context.Context, baseDir string, include, exclude []string) ([]*RouteSource, error) { + baseDir = strings.TrimSpace(baseDir) + if baseDir == "" { + return nil, fmt.Errorf("go router discovery: base dir was empty") + } + if len(include) == 0 { + return nil, fmt.Errorf("go router discovery: include package patterns were empty") + } + include, err := expandPackagePatterns(ctx, baseDir, include) + if err != nil { + return nil, err + } + if len(include) == 0 { + return nil, fmt.Errorf("go router discovery: no packages matched include patterns") + } + loadCfg := &packages.Config{ + Context: ctx, + Dir: baseDir, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedSyntax, + } + pkgs, err := packages.Load(loadCfg, include...) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to load packages: %w", err) + } + index, err := newPackageIndex(ctx, baseDir) + if err != nil { + return nil, err + } + var result []*RouteSource + for _, pkg := range pkgs { + if pkg == nil || pkg.PkgPath == "" || matchesPackagePatternList(pkg.PkgPath, exclude) { + continue + } + dir := firstPackageDir(pkg) + if dir == "" { + continue + } + for i, file := range pkg.Syntax { + if file == nil || i >= len(pkg.GoFiles) { + continue + } + filePath := pkg.GoFiles[i] + imports := importMap(file) + discovered, err := index.routesInFile(pkg.PkgPath, pkg.Name, dir, filePath, file, imports) + if err != nil { + return nil, err + } + result = append(result, discovered...) + } + } + sort.SliceStable(result, func(i, j int) bool { + if result[i].PackagePath == result[j].PackagePath { + if result[i].FilePath == result[j].FilePath { + return result[i].FieldName < result[j].FieldName + } + return result[i].FilePath < result[j].FilePath + } + return result[i].PackagePath < result[j].PackagePath + }) + return result, nil +} + +func expandPackagePatterns(ctx context.Context, baseDir string, patterns []string) ([]string, error) { + unique := map[string]bool{} + var result []string + for _, pattern := range patterns { + pattern = strings.TrimSpace(pattern) + if pattern == "" { + continue + } + expanded, err := expandPackagePattern(ctx, baseDir, pattern) + if err != nil { + return nil, err + } + for _, item := range expanded { + item = strings.TrimSpace(item) + if item == "" || unique[item] { + continue + } + unique[item] = true + result = append(result, item) + } + } + sort.Strings(result) + return result, nil +} + +func expandPackagePattern(ctx context.Context, baseDir, pattern string) ([]string, error) { + if !strings.HasSuffix(pattern, "/...") { + return []string{pattern}, nil + } + moduleDir, modulePath, err := locateModule(baseDir) + if err == nil { + if packages, ok, expandErr := expandModuleWildcardPattern(moduleDir, modulePath, pattern); expandErr != nil { + return nil, expandErr + } else if ok { + return packages, nil + } + } + cfg := &packages.Config{ + Context: ctx, + Dir: baseDir, + Mode: packages.NeedName | packages.NeedFiles, + } + pkgs, err := packages.Load(cfg, pattern) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to expand package pattern %s: %w", pattern, err) + } + var result []string + unique := map[string]bool{} + for _, pkg := range pkgs { + if pkg == nil || pkg.PkgPath == "" || unique[pkg.PkgPath] { + continue + } + unique[pkg.PkgPath] = true + result = append(result, pkg.PkgPath) + } + sort.Strings(result) + return result, nil +} + +func expandModuleWildcardPattern(moduleDir, modulePath, pattern string) ([]string, bool, error) { + prefix := strings.TrimSuffix(strings.TrimSpace(pattern), "/...") + if prefix == "" || moduleDir == "" || modulePath == "" { + return nil, false, nil + } + if prefix != modulePath && !strings.HasPrefix(prefix, modulePath+"/") { + return nil, false, nil + } + rel := strings.TrimPrefix(prefix, modulePath) + rel = strings.TrimPrefix(rel, "/") + rootDir := moduleDir + if rel != "" { + rootDir = filepath.Join(moduleDir, filepath.FromSlash(rel)) + } + info, err := os.Stat(rootDir) + if err != nil { + if os.IsNotExist(err) { + return nil, true, nil + } + return nil, true, err + } + if !info.IsDir() { + return nil, true, nil + } + unique := map[string]bool{} + var result []string + err = filepath.WalkDir(rootDir, func(current string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if !d.IsDir() { + return nil + } + name := d.Name() + if strings.HasPrefix(name, ".") || name == "testdata" { + if current != rootDir { + return filepath.SkipDir + } + return nil + } + hasGo, err := containsPackageGoFiles(current) + if err != nil { + return err + } + if !hasGo { + return nil + } + relDir, err := filepath.Rel(moduleDir, current) + if err != nil { + return err + } + importPath := modulePath + if relDir != "." { + importPath += "/" + filepath.ToSlash(relDir) + } + if !unique[importPath] { + unique[importPath] = true + result = append(result, importPath) + } + return nil + }) + if err != nil { + return nil, true, err + } + sort.Strings(result) + return result, true, nil +} + +func containsPackageGoFiles(dir string) (bool, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return false, err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { + continue + } + return true, nil + } + return false, nil +} + +func locateModule(baseDir string) (string, string, error) { + dir := filepath.Clean(baseDir) + for { + goModPath := filepath.Join(dir, "go.mod") + data, err := os.ReadFile(goModPath) + if err == nil { + modulePath := parseModulePath(data) + if modulePath == "" { + return "", "", fmt.Errorf("go router discovery: module path not found in %s", goModPath) + } + return dir, modulePath, nil + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return "", "", fmt.Errorf("go router discovery: go.mod not found from %s", baseDir) +} + +func parseModulePath(data []byte) string { + scanner := bufio.NewScanner(strings.NewReader(string(data))) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "//") { + continue + } + if strings.HasPrefix(line, "module ") { + return strings.TrimSpace(strings.TrimPrefix(line, "module ")) + } + } + return "" +} + +type packageIndex struct { + ctx context.Context + baseDir string + pkgs map[string]*packageMeta + dirTypes map[string]*xreflect.DirTypes +} + +type packageMeta struct { + importPath string + name string + dir string +} + +func newPackageIndex(ctx context.Context, baseDir string) (*packageIndex, error) { + return &packageIndex{ + ctx: ctx, + baseDir: baseDir, + pkgs: map[string]*packageMeta{}, + dirTypes: map[string]*xreflect.DirTypes{}, + }, nil +} + +func (p *packageIndex) routesInFile(pkgPath, pkgName, pkgDir, filePath string, file *ast.File, imports map[string]string) ([]*RouteSource, error) { + var result []*RouteSource + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + for _, spec := range gen.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + continue + } + for _, field := range structType.Fields.List { + route, err := p.routeFromField(pkgPath, pkgName, pkgDir, filePath, field, imports) + if err != nil { + return nil, err + } + if route != nil { + result = append(result, route) + } + } + } + } + return result, nil +} + +func (p *packageIndex) routeFromField(pkgPath, pkgName, pkgDir, filePath string, field *ast.Field, imports map[string]string) (*RouteSource, error) { + if field == nil || field.Tag == nil || len(field.Names) == 0 { + return nil, nil + } + tagLiteral, err := strconv.Unquote(field.Tag.Value) + if err != nil { + return nil, fmt.Errorf("go router discovery: invalid struct tag in %s: %w", filePath, err) + } + parsed, err := componenttag.Parse(reflect.StructTag(tagLiteral)) + if err != nil { + return nil, fmt.Errorf("go router discovery: invalid component tag in %s: %w", filePath, err) + } + if parsed == nil || parsed.Component == nil { + return nil, nil + } + fieldName := strings.TrimSpace(field.Names[0].Name) + if fieldName == "" { + return nil, nil + } + inputRef := normalizeTypeRef(strings.TrimSpace(parsed.Component.Input), pkgPath) + outputRef := normalizeTypeRef(strings.TrimSpace(parsed.Component.Output), pkgPath) + viewRef := normalizeTypeRef(strings.TrimSpace(parsed.Component.View), pkgPath) + handlerRef := normalizeTypeRef(strings.TrimSpace(parsed.Component.Handler), pkgPath) + sourceURL := strings.TrimSpace(parsed.Component.Source) + summaryURL := strings.TrimSpace(parsed.Component.Summary) + if inputRef == "" || outputRef == "" { + inferredInput, inferredOutput := inferComponentTypeRefs(field.Type, pkgPath, imports) + if inputRef == "" { + inputRef = inferredInput + } + if outputRef == "" { + outputRef = inferredOutput + } + } + if inputRef == "" && outputRef == "" { + if viewRef == "" && sourceURL == "" { + return nil, nil + } + } + if inputRef == "" && outputRef == "" && viewRef == "" { + return nil, nil + } + registry := x.NewRegistry() + tagCopy := *parsed.Component + if inputRef != "" { + rType, err := p.resolveType(inputRef) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to resolve %s input %s: %w", fieldName, inputRef, err) + } + registerType(registry, inputRef, rType) + tagCopy.Input = inputRef + } + if outputRef != "" { + rType, err := p.resolveType(outputRef) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to resolve %s output %s: %w", fieldName, outputRef, err) + } + registerType(registry, outputRef, rType) + tagCopy.Output = outputRef + } + if viewRef != "" { + rType, err := p.resolveType(viewRef) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to resolve %s view %s: %w", fieldName, viewRef, err) + } + registerType(registry, viewRef, rType) + tagCopy.View = viewRef + } + if handlerRef != "" { + rType, err := p.resolveType(handlerRef) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to resolve %s handler %s: %w", fieldName, handlerRef, err) + } + registerType(registry, handlerRef, rType) + tagCopy.Handler = handlerRef + } + if sourceURL != "" { + tagCopy.Source = sourceURL + } + if summaryURL != "" { + tagCopy.Summary = summaryURL + } + componentTag := tagCopy.Tag() + rootType := reflect.StructOf([]reflect.StructField{{ + Name: exportName(fieldName), + Type: reflect.TypeOf(struct{}{}), + Tag: reflect.StructTag(tagtags.Tags{componentTag}.Stringify()), + }}) + name := strings.TrimSpace(tagCopy.Name) + if name == "" { + name = exportName(fieldName) + } + return &RouteSource{ + Name: name, + FieldName: fieldName, + FilePath: filePath, + PackageName: pkgName, + PackagePath: pkgPath, + PackageDir: pkgDir, + RoutePath: strings.TrimSpace(tagCopy.Path), + Method: strings.TrimSpace(tagCopy.Method), + Connector: strings.TrimSpace(tagCopy.Connector), + InputRef: inputRef, + OutputRef: outputRef, + ViewRef: viewRef, + SourceURL: sourceURL, + SummaryURL: summaryURL, + Source: &shape.Source{ + Name: name, + Path: filePath, + Type: rootType, + TypeRegistry: registry, + }, + }, nil +} + +func (p *packageIndex) resolveType(ref string) (reflect.Type, error) { + pkgPath, typeName := splitTypeRef(ref) + if pkgPath == "" || typeName == "" { + return nil, fmt.Errorf("invalid type reference %q", ref) + } + if extension.Config != nil && extension.Config.Types != nil { + if linked, err := extension.Config.Types.Lookup(typeName, xreflect.WithPackage(pkgPath)); err == nil && linked != nil { + return linked, nil + } + } + meta, err := p.packageMeta(pkgPath) + if err != nil { + return nil, err + } + dirTypes, err := p.dirTypesFor(meta.dir) + if err != nil { + return nil, err + } + rType, err := dirTypes.Type(typeName) + if err != nil { + return nil, err + } + return rType, nil +} + +func (p *packageIndex) packageMeta(importPath string) (*packageMeta, error) { + if meta, ok := p.pkgs[importPath]; ok { + return meta, nil + } + cfg := &packages.Config{ + Context: p.ctx, + Dir: p.baseDir, + Mode: packages.NeedName | packages.NeedFiles, + } + pkgs, err := packages.Load(cfg, importPath) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to load package %s: %w", importPath, err) + } + for _, pkg := range pkgs { + if pkg == nil || pkg.PkgPath == "" { + continue + } + dir := firstPackageDir(pkg) + if dir == "" { + continue + } + meta := &packageMeta{importPath: pkg.PkgPath, name: pkg.Name, dir: dir} + p.pkgs[pkg.PkgPath] = meta + if pkg.PkgPath == importPath { + return meta, nil + } + } + return nil, fmt.Errorf("go router discovery: package %s not resolved", importPath) +} + +func (p *packageIndex) dirTypesFor(dir string) (*xreflect.DirTypes, error) { + if cached, ok := p.dirTypes[dir]; ok { + return cached, nil + } + options := []xreflect.Option{} + if extension.Config != nil && extension.Config.Types != nil { + options = append(options, xreflect.WithTypeLookup(extension.Config.Types.Lookup)) + } + parsed, err := xreflect.ParseTypes(dir, options...) + if err != nil { + return nil, fmt.Errorf("go router discovery: failed to parse package dir %s: %w", dir, err) + } + p.dirTypes[dir] = parsed + return parsed, nil +} + +func inferComponentTypeRefs(expr ast.Expr, pkgPath string, imports map[string]string) (string, string) { + args := componentTypeArgs(expr) + if len(args) < 2 { + return "", "" + } + return qualifyTypeExpr(args[0], pkgPath, imports), qualifyTypeExpr(args[1], pkgPath, imports) +} + +func componentTypeArgs(expr ast.Expr) []ast.Expr { + switch actual := expr.(type) { + case *ast.IndexListExpr: + if !isComponentSelector(actual.X) { + return nil + } + return actual.Indices + case *ast.IndexExpr: + if !isComponentSelector(actual.X) { + return nil + } + return []ast.Expr{actual.Index} + default: + return nil + } +} + +func isComponentSelector(expr ast.Expr) bool { + switch actual := expr.(type) { + case *ast.SelectorExpr: + return actual.Sel != nil && actual.Sel.Name == "Component" + case *ast.Ident: + return actual.Name == "Component" + default: + return false + } +} + +func qualifyTypeExpr(expr ast.Expr, pkgPath string, imports map[string]string) string { + switch actual := expr.(type) { + case *ast.Ident: + if pkgPath == "" || actual.Name == "" { + return "" + } + return pkgPath + "." + actual.Name + case *ast.SelectorExpr: + ident, ok := actual.X.(*ast.Ident) + if !ok || ident.Name == "" || actual.Sel == nil || actual.Sel.Name == "" { + return "" + } + importPath := imports[ident.Name] + if importPath == "" { + return "" + } + return importPath + "." + actual.Sel.Name + default: + return "" + } +} + +func importMap(file *ast.File) map[string]string { + result := map[string]string{} + if file == nil { + return result + } + for _, item := range file.Imports { + if item == nil || item.Path == nil { + continue + } + importPath, err := strconv.Unquote(item.Path.Value) + if err != nil || importPath == "" { + continue + } + alias := path.Base(importPath) + if item.Name != nil && strings.TrimSpace(item.Name.Name) != "" { + alias = strings.TrimSpace(item.Name.Name) + } + result[alias] = importPath + } + return result +} + +func splitTypeRef(ref string) (string, string) { + ref = strings.TrimSpace(ref) + if ref == "" { + return "", "" + } + index := strings.LastIndex(ref, ".") + if index == -1 || index+1 >= len(ref) { + return "", "" + } + return strings.TrimSpace(ref[:index]), strings.TrimSpace(ref[index+1:]) +} + +func normalizeTypeRef(ref, pkgPath string) string { + ref = strings.TrimSpace(ref) + if ref == "" { + return "" + } + if strings.Contains(ref, ".") { + return ref + } + if pkgPath == "" { + return ref + } + return pkgPath + "." + ref +} + +func registerType(registry *x.Registry, ref string, rType reflect.Type) { + if registry == nil || rType == nil { + return + } + pkgPath, typeName := splitTypeRef(ref) + registry.Register(x.NewType(rType, x.WithPkgPath(pkgPath), x.WithName(typeName))) +} + +func firstPackageDir(pkg *packages.Package) string { + if pkg == nil { + return "" + } + for _, filePath := range pkg.GoFiles { + if filePath == "" { + continue + } + return filepath.Dir(filePath) + } + return "" +} + +func exportName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "Route" + } + runes := []rune(name) + if len(runes) == 0 { + return "Route" + } + if runes[0] >= 'a' && runes[0] <= 'z' { + runes[0] = runes[0] - 32 + } + return string(runes) +} + +func matchesPackagePatternList(pkg string, patterns []string) bool { + for _, pattern := range patterns { + if matchesPackagePattern(pkg, pattern) { + return true + } + } + return false +} + +func matchesPackagePattern(pkg, pattern string) bool { + pkg = strings.TrimSpace(pkg) + pattern = strings.TrimSpace(pattern) + if pkg == "" || pattern == "" { + return false + } + if strings.HasSuffix(pattern, "/...") { + prefix := strings.TrimSuffix(pattern, "/...") + return pkg == prefix || strings.HasPrefix(pkg, prefix+"/") + } + return pkg == pattern +} diff --git a/repository/shape/gorouter/discover_test.go b/repository/shape/gorouter/discover_test.go new file mode 100644 index 000000000..5a21b2d64 --- /dev/null +++ b/repository/shape/gorouter/discover_test.go @@ -0,0 +1,266 @@ +package gorouter + +import ( + "context" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/viant/datly/view/extension" + "github.com/viant/xreflect" +) + +func TestDiscover_MultiFieldRouters(t *testing.T) { + baseDir := t.TempDir() + writeFile(t, filepath.Join(baseDir, "go.mod"), "module example.com/app\n\ngo 1.24\n") + writeFile(t, filepath.Join(baseDir, "pkg", "routes", "routes.go"), `package routes + +type ReportView struct { + ID int `+"`"+`sqlx:"ID"`+"`"+` +} + +type ReportInput struct { + ID int `+"`"+`parameter:",kind=path,in=id"`+"`"+` +} + +type ReportOutput struct { + Data []*ReportView `+"`"+`parameter:",kind=output,in=view"`+"`"+` +} + +type CreateInput struct { + Name string `+"`"+`parameter:",kind=body,in=name"`+"`"+` +} + +type CreateOutput struct { + Status string `+"`"+`parameter:",kind=output,in=status"`+"`"+` +} + +type Router struct { + Report struct{} `+"`"+`component:",path=/v1/report/{id},method=GET,input=ReportInput,output=ReportOutput"`+"`"+` + Create struct{} `+"`"+`component:",path=/v1/report,method=POST,input=CreateInput,output=CreateOutput"`+"`"+` +} +`) + + routes, err := Discover(context.Background(), baseDir, []string{"example.com/app/pkg/..."}, nil) + if err != nil { + t.Fatalf("discover failed: %v", err) + } + if len(routes) != 2 { + t.Fatalf("unexpected route count: %d", len(routes)) + } + if routes[0].InputRef == "" || routes[0].OutputRef == "" { + t.Fatalf("expected fully-qualified contract refs, but had %#v", routes[0]) + } + if routes[0].Source == nil || routes[0].Source.TypeRegistry == nil { + t.Fatalf("expected synthetic source with registry") + } + if routes[0].Source.Type == nil { + t.Fatalf("expected synthetic root type") + } +} + +func TestDiscover_ExcludePattern(t *testing.T) { + baseDir := t.TempDir() + writeFile(t, filepath.Join(baseDir, "go.mod"), "module example.com/app\n\ngo 1.24\n") + writeFile(t, filepath.Join(baseDir, "pkg", "one", "one.go"), "package one\ntype In struct{}\ntype Out struct{}\ntype Router struct { Route struct{} `component:\",path=/one,method=GET,input=In,output=Out\"` }\n") + writeFile(t, filepath.Join(baseDir, "pkg", "two", "two.go"), "package two\ntype In struct{}\ntype Out struct{}\ntype Router struct { Route struct{} `component:\",path=/two,method=GET,input=In,output=Out\"` }\n") + + routes, err := Discover(context.Background(), baseDir, []string{"example.com/app/pkg/..."}, []string{"example.com/app/pkg/two"}) + if err != nil { + t.Fatalf("discover failed: %v", err) + } + if len(routes) != 1 { + t.Fatalf("unexpected route count: %d", len(routes)) + } + if routes[0].RoutePath != "/one" { + t.Fatalf("unexpected route path: %s", routes[0].RoutePath) + } +} + +func TestDiscover_WildcardIncludesVendorSubtree(t *testing.T) { + baseDir := t.TempDir() + writeFile(t, filepath.Join(baseDir, "go.mod"), "module example.com/app\n\ngo 1.24\n") + writeFile(t, filepath.Join(baseDir, "shape", "dev", "vendor", "list", "vendor.go"), `package list + +type VendorInput struct { + ID int `+"`"+`parameter:",kind=path,in=id"`+"`"+` +} + +type VendorOutput struct { +} + +type VendorRouter struct { + Vendor struct{} `+"`"+`component:",path=/v1/vendors/{id},method=GET,input=VendorInput,output=VendorOutput"`+"`"+` +} +`) + writeFile(t, filepath.Join(baseDir, "shape", "dev", "team", "delete", "team.go"), `package delete + +type TeamInput struct { + ID int `+"`"+`parameter:",kind=path,in=id"`+"`"+` +} + +type TeamOutput struct{} + +type TeamRouter struct { + Team struct{} `+"`"+`component:",path=/v1/team/{id},method=DELETE,input=TeamInput,output=TeamOutput"`+"`"+` +} +`) + + routes, err := Discover(context.Background(), baseDir, []string{"example.com/app/shape/dev/..."}, nil) + if err != nil { + t.Fatalf("discover failed: %v", err) + } + if len(routes) != 2 { + t.Fatalf("unexpected route count: %d", len(routes)) + } + foundVendor := false + for _, route := range routes { + if route != nil && route.PackagePath == "example.com/app/shape/dev/vendor/list" { + foundVendor = true + break + } + } + if !foundVendor { + t.Fatalf("expected vendor subtree package to be discovered, got %#v", routes) + } +} + +func TestDiscover_ResolvesImportedEmbeddedOutputType(t *testing.T) { + extension.InitRegistry() + baseDir := t.TempDir() + writeFile(t, filepath.Join(baseDir, "go.mod"), "module example.com/app\n\ngo 1.24\n") + writeFile(t, filepath.Join(baseDir, "pkg", "routes", "routes.go"), `package routes + +import "github.com/viant/xdatly/handler/response" + +type ReportInput struct { + ID int `+"`"+`parameter:",kind=path,in=id"`+"`"+` +} + +type ReportOutput struct { + response.Status `+"`"+`parameter:",kind=output,in=status"`+"`"+` +} + +type Router struct { + Report struct{} `+"`"+`component:",path=/v1/report/{id},method=GET,input=ReportInput,output=ReportOutput"`+"`"+` +} +`) + + routes, err := Discover(context.Background(), baseDir, []string{"example.com/app/pkg/..."}, nil) + if err != nil { + t.Fatalf("discover failed: %v", err) + } + if len(routes) != 1 { + t.Fatalf("unexpected route count: %d", len(routes)) + } + if routes[0].OutputRef != "example.com/app/pkg/routes.ReportOutput" { + t.Fatalf("unexpected output ref: %s", routes[0].OutputRef) + } +} + +func TestDiscover_NormalizesHandlerType(t *testing.T) { + baseDir := t.TempDir() + writeFile(t, filepath.Join(baseDir, "go.mod"), "module example.com/app\n\ngo 1.24\n") + writeFile(t, filepath.Join(baseDir, "pkg", "routes", "routes.go"), `package routes + +import ( + "context" + xhandler "github.com/viant/xdatly/handler" +) + +type ReportInput struct { + ID int `+"`"+`parameter:",kind=path,in=id"`+"`"+` +} + +type ReportOutput struct { + OK bool `+"`"+`parameter:",kind=output,in=view"`+"`"+` +} + +type Handler struct{} + +func (h *Handler) Exec(ctx context.Context, sess xhandler.Session) (interface{}, error) { + return ReportOutput{OK: true}, nil +} + +type Router struct { + Report struct{} `+"`"+`component:",path=/v1/report/{id},method=GET,input=ReportInput,output=ReportOutput,handler=Handler"`+"`"+` +} +`) + + routes, err := Discover(context.Background(), baseDir, []string{"example.com/app/pkg/..."}, nil) + if err != nil { + t.Fatalf("discover failed: %v", err) + } + if len(routes) != 1 { + t.Fatalf("unexpected route count: %d", len(routes)) + } + tag := routes[0].Source.Type.Field(0).Tag.Get("component") + if tag == "" || filepath.Base(routes[0].PackagePath) == "" { + t.Fatalf("expected route component tag to be present") + } + if got := routes[0].Source.Type.Field(0).Tag.Get("component"); !strings.Contains(got, "handler=example.com/app/pkg/routes.Handler") { + t.Fatalf("expected normalized handler ref in component tag, got %q", got) + } +} + +func TestDiscover_PrefersLinkedNamedType(t *testing.T) { + extension.InitRegistry() + type linkedReportView struct { + ID int `sqlx:"ID"` + Name string `sqlx:"NAME"` + } + if err := extension.Config.Types.Register("ReportView", + xreflect.WithPackage("example.com/app/pkg/routes"), + xreflect.WithReflectType(reflect.TypeOf(linkedReportView{})), + ); err != nil { + t.Fatalf("register linked type: %v", err) + } + + baseDir := t.TempDir() + writeFile(t, filepath.Join(baseDir, "go.mod"), "module example.com/app\n\ngo 1.24\n") + writeFile(t, filepath.Join(baseDir, "pkg", "routes", "routes.go"), `package routes + +type ReportView struct { + Items []*struct { + ID int `+"`"+`sqlx:"ID"`+"`"+` + } `+"`"+`view:",table=ITEM"`+"`"+` +} + +type ReportInput struct{} + +type ReportOutput struct { + Data *ReportView `+"`"+`parameter:",kind=output,in=view"`+"`"+` +} + +type Router struct { + Report struct{} `+"`"+`component:",path=/v1/report,method=GET,input=ReportInput,output=ReportOutput,view=ReportView"`+"`"+` +} +`) + + routes, err := Discover(context.Background(), baseDir, []string{"example.com/app/pkg/..."}, nil) + if err != nil { + t.Fatalf("discover failed: %v", err) + } + if len(routes) != 1 { + t.Fatalf("unexpected route count: %d", len(routes)) + } + lookup := routes[0].Source.TypeRegistry.Lookup("example.com/app/pkg/routes.ReportView") + if lookup == nil || lookup.Type == nil { + t.Fatalf("expected route view to be registered") + } + if got, want := lookup.Type, reflect.TypeOf(linkedReportView{}); got != want { + t.Fatalf("expected linked route view type %v, got %v", want, got) + } +} + +func writeFile(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("write failed: %v", err) + } +} diff --git a/repository/shape/gorouter/model.go b/repository/shape/gorouter/model.go new file mode 100644 index 000000000..d93bd2ea5 --- /dev/null +++ b/repository/shape/gorouter/model.go @@ -0,0 +1,22 @@ +package gorouter + +import "github.com/viant/datly/repository/shape" + +// RouteSource represents one component route field discovered from a Go source package. +type RouteSource struct { + Name string + FieldName string + FilePath string + PackageName string + PackagePath string + PackageDir string + RoutePath string + Method string + Connector string + InputRef string + OutputRef string + ViewRef string + SourceURL string + SummaryURL string + Source *shape.Source +} diff --git a/repository/shape/improvement.md b/repository/shape/improvement.md new file mode 100644 index 000000000..aceb6f67d --- /dev/null +++ b/repository/shape/improvement.md @@ -0,0 +1,267 @@ +# Shape Improvement Proposal + +This note captures the main internal improvements suggested by the translator-to-shape migration work. + +Scope: + +- Applies to `repository/shape` +- Focuses on `DQL -> shape -> IR` +- Uses migration findings from grouping, summary, selector, and generated patch routes + +## Goals + +- Make shape the authoritative semantic model for DQL and Go-derived routes. +- Reduce runtime/bootstrap recovery logic. +- Replace translator-era implicit behavior with explicit shape metadata. +- Keep load/materialization modular so views, components, and resources can be built independently. + +## 1. Promote `ComponentRoute` As A First-Class Shape Primitive + +Observed gap: + +- Route path, method, template strategy, and related component-level metadata were historically reconstructed outside shape. +- That encouraged direct `DQL -> IR` workarounds. + +Proposal: + +- Treat `ComponentRoute` as a first-class primitive produced by DQL compile. +- Carry at minimum: + - `Method` + - `RoutePath` + - `TemplateType` + - route-level connector/defaults + - route-level metadata/docs/cache/auth flags + +Target: + +- `DQL -> ComponentRoute` +- `DQL -> View` +- `shape/load -> component/resource IR` + +Benefit: + +- Transcribe, bootstrap, and future `AddRoute` APIs can consume route metadata from shape only. + +## 2. Make Template Strategy Explicit + +Observed gap: + +- Generated patch routes and translated exec routes were previously distinguished by inference. +- That was fragile and led to runtime fallbacks. + +Proposal: + +- Keep explicit DQL settings such as: + - `#setting($_ = $useTemplate('patch'))` +- Store the resolved value on shape route metadata as `TemplateType`. + +Recommended semantics: + +- `translate` or empty: preserve authored DQL/Velty behavior +- `patch`: synthesize mutable Velty from shape AST/metadata +- future values may include `post`, `put`, `upsert` + +Benefit: + +- Removes heuristic detection of generated mutable routes. + +## 3. Eliminate Runtime Type Recovery For Helper Parameters + +Observed gap: + +- Generated patch helpers such as `CurFoosId` and `CurFoos` needed runtime recovery of source type information. +- Local generator paths were more explicit than the early `v1` shape/transcribe path. + +Proposal: + +- Shape/transcribe should emit enough source/output type metadata so runtime codec initialization does not need to infer types from referenced params. +- Helper params should carry explicit source owner type and output type in shape/IR. + +Benefit: + +- Moves correctness back into shape. +- Reduces special handling in `view/state/parameter.go` and codec initialization. + +## 4. Keep View-Level And Column-Level Semantics Separate + +Observed gap: + +- Grouping and selector metadata were easy to blur across view and column layers. + +Proposal: + +- View-level metadata stays explicit: + - `Groupable` + - selector namespace + - selector constraints + - summary URI / summary behavior +- Column-level metadata is explicit or inferred independently: + - `ColumnConfig.Groupable` + - inferred grouped projections from `GROUP BY` + +Benefit: + +- Avoids deriving view semantics from column accidents. +- Keeps Go tags and DQL hints aligned. + +## 5. Add Dedicated Shape Primitives For Selector Holders + +Observed gap: + +- Flattening query selector fields into business input types makes Go shape contracts noisy and semantically wrong. + +Proposal: + +- Keep query-selector state as a separate shape concept. +- Support Go-derived contracts like: + - business input holder + - selector holder tagged with `querySelector:"viewAlias"` + +Target Go model: + +- `VendorInput` remains business input +- `ViewSelect` remains selector state +- shape merges both into component contract IR + +Benefit: + +- Aligns Go-derived shape with the DQL selector model. + +## 6. Make Summary A Real Shape Concept, Not A Side Effect + +Observed gap: + +- Summary handling drifted between tags, parent view attachment, and runtime conventions. +- Multi-level summaries exposed gaps in child summary attachment and typing. + +Proposal: + +- Represent summary explicitly in shape at any view level. +- Include: + - summary target view/ref + - summary URI/source + - parent attachment semantics + - summary output schema/type + +Benefit: + +- Root summaries and child summaries can be materialized consistently from shape. + +## 7. Add Recursive Mutable Generation For Nested Graphs + +Observed gap: + +- `patch_basic_one` and `patch_basic_many` became stable, but nested mutable graphs such as many-many flows need more general helper synthesis. + +Proposal: + +- Generalize mutable generation to recurse across relation graphs. +- Generate helper views, `IndexBy` maps, key propagation, and DML blocks per mutable node. + +Examples: + +- root collection patch +- nested child collection patch +- nested key propagation such as `FooId = parent.Id` + +Benefit: + +- Closes the remaining gap between local generate flows and shape-generated mutable routes. + +## 8. Introduce A Strong Shape Validation Stage + +Observed gap: + +- Some failures were discovered too late at bootstrap/runtime. + +Proposal: + +- Expand `datly validate` as the primary shape-only validation gate. +- Validate: + - DQL syntax and directives + - SQL asset existence + - route metadata completeness + - helper type completeness + - selector/summary/grouping consistency + - generated mutable prerequisites + +Benefit: + +- Detects incomplete shape before runtime. + +## 9. Add Deterministic Diagnostics Codes + +Observed gap: + +- Migration debugging spent too much time on ad hoc runtime errors. + +Proposal: + +- Extend shape diagnostics with stable codes across: + - route metadata + - selector metadata + - summary attachment + - mutable helper generation + - groupable inference + - type collisions + +Benefit: + +- Better tooling, tests, and compile-time failure handling. + +## 10. Keep Load Modular By Primitive + +Observed gap: + +- Some behavior was easier to validate once view/component/resource loading was separated. + +Proposal: + +- Continue building around primitive loaders: + - `LoadView` + - `LoadComponentRoute` + - `LoadComponent` + - `LoadResource` +- Keep both inputs supported: + - Go types + - DQL + +Benefit: + +- Allows future APIs such as `AddRoute` to stay thin. +- Makes unit coverage sharper and reduces cross-coupled runtime fixes. + +## 11. Reduce Direct Translator Dependence To Parity Specs And Fixtures + +Observed gap: + +- Migration often required checking local regression translator output to understand target semantics. + +Proposal: + +- Treat translator/local regression outputs as parity fixtures, not active implementation dependencies. +- Keep explicit parity docs and focused regression fixtures in shape tests. + +Benefit: + +- Shape remains independent while still preserving observable legacy behavior. + +## Suggested Priority + +1. `ComponentRoute` ownership in shape +2. Explicit `TemplateType` +3. Remove runtime helper type recovery by emitting complete helper metadata +4. Summary as explicit shape metadata +5. Recursive mutable generation +6. Broader `datly validate` coverage +7. Diagnostics standardization + +## Success Criteria + +The migration is structurally complete when: + +- Bootstrap and transcribe no longer need to reconstruct missing semantics from raw DQL. +- Generated patch routes are selected explicitly, not inferred heuristically. +- Summary, selector, and grouping behavior are fully representable in shape. +- Runtime does not need shape-recovery logic for helper/source types. +- Local translator outputs are matched by shape through tests, not through runtime workarounds. diff --git a/repository/shape/load/columns.go b/repository/shape/load/columns.go index 147a2a5b4..015e41efd 100644 --- a/repository/shape/load/columns.go +++ b/repository/shape/load/columns.go @@ -32,6 +32,9 @@ func inferColumnsFromType(rType reflect.Type) []*view.Column { if !f.IsExported() { continue } + if shouldSkipInferredField(f) { + continue + } colName := sqlxColumnName(f) if colName == "" { colName = f.Name @@ -44,6 +47,46 @@ func inferColumnsFromType(rType reflect.Type) []*view.Column { return cols } +func shouldSkipInferredField(field reflect.StructField) bool { + if field.Name == "-" { + return true + } + rawTag := string(field.Tag) + if strings.Contains(rawTag, `view:"`) || strings.Contains(rawTag, `on:"`) { + return true + } + if strings.Contains(rawTag, `sqlx:"-"`) { + return true + } + return false +} + +func inferredColumnsArePlaceholders(columns []*view.Column) bool { + if len(columns) == 0 { + return false + } + for _, column := range columns { + if column == nil || !isPlaceholderColumnName(column.Name) { + return false + } + } + return true +} + +func isPlaceholderColumnName(name string) bool { + name = strings.TrimSpace(strings.ToLower(name)) + name = strings.ReplaceAll(name, "_", "") + if !strings.HasPrefix(name, "col") || len(name) == len("col") { + return false + } + for i := len("col"); i < len(name); i++ { + if name[i] < '0' || name[i] > '9' { + return false + } + } + return true +} + // sqlxColumnName reads the sqlx struct tag to get the database column name. func sqlxColumnName(f reflect.StructField) string { tag := f.Tag.Get("sqlx") @@ -52,9 +95,15 @@ func sqlxColumnName(f reflect.StructField) string { } for _, part := range strings.Split(tag, ",") { part = strings.TrimSpace(part) + if part == "" { + continue + } if strings.HasPrefix(part, "name=") { return strings.TrimPrefix(part, "name=") } + if !strings.Contains(part, "=") { + return part + } } return "" } diff --git a/repository/shape/load/columns_test.go b/repository/shape/load/columns_test.go new file mode 100644 index 000000000..96a326745 --- /dev/null +++ b/repository/shape/load/columns_test.go @@ -0,0 +1,24 @@ +package load + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +type sampleInferredColumnsRoot struct { + ID int `sqlx:"ID"` + Products []*sampleInferredRel `view:",table=PRODUCT" on:"Id:ID=VendorId:VENDOR_ID"` + Ignored string `sqlx:"-"` +} + +type sampleInferredRel struct { + VendorID int `sqlx:"VENDOR_ID"` +} + +func TestInferColumnsFromType_SkipsSemanticFields(t *testing.T) { + cols := inferColumnsFromType(reflect.TypeOf(sampleInferredColumnsRoot{})) + require.Len(t, cols, 1) + require.Equal(t, "ID", cols[0].Name) +} diff --git a/repository/shape/load/loader.go b/repository/shape/load/loader.go index 03277feff..d5013ecb5 100644 --- a/repository/shape/load/loader.go +++ b/repository/shape/load/loader.go @@ -3,19 +3,34 @@ package load import ( "context" "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path" + "path/filepath" "reflect" + "regexp" + "sort" "strings" "time" "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" "github.com/viant/datly/repository/shape/typectx" shapevalidate "github.com/viant/datly/repository/shape/validate" "github.com/viant/datly/shared" + "github.com/viant/datly/utils/types" "github.com/viant/datly/view" "github.com/viant/datly/view/extension" "github.com/viant/datly/view/state" + "github.com/viant/datly/view/tags" + "github.com/viant/sqlparser" + sqlxio "github.com/viant/sqlx/io" + "github.com/viant/tagly/format/text" + "github.com/viant/xdatly/handler/response" ) // Loader materializes runtime view artifacts from normalized shape plan. @@ -27,11 +42,17 @@ func New() *Loader { } // LoadViews implements shape.Loader. -func (l *Loader) LoadViews(ctx context.Context, planned *shape.PlanResult, _ ...shape.LoadOption) (*shape.ViewArtifacts, error) { +func (l *Loader) LoadViews(ctx context.Context, planned *shape.PlanResult, opts ...shape.LoadOption) (*shape.ViewArtifacts, error) { if err := ctx.Err(); err != nil { return nil, err } - pResult, resource, err := l.materialize(planned) + loadOptions := &shape.LoadOptions{} + for _, opt := range opts { + if opt != nil { + opt(loadOptions) + } + } + pResult, resource, err := l.materialize(ctx, planned, loadOptions) if err != nil { return nil, err } @@ -41,26 +62,80 @@ func (l *Loader) LoadViews(ctx context.Context, planned *shape.PlanResult, _ ... return &shape.ViewArtifacts{Resource: resource, Views: resource.Views}, nil } +// LoadResource implements shape.Loader. +func (l *Loader) LoadResource(ctx context.Context, planned *shape.PlanResult, opts ...shape.LoadOption) (*shape.ResourceArtifacts, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + loadOptions := &shape.LoadOptions{} + for _, opt := range opts { + if opt != nil { + opt(loadOptions) + } + } + _, resource, err := l.materialize(ctx, planned, loadOptions) + if err != nil { + return nil, err + } + return &shape.ResourceArtifacts{Resource: resource}, nil +} + // LoadComponent implements shape.Loader. -func (l *Loader) LoadComponent(ctx context.Context, planned *shape.PlanResult, _ ...shape.LoadOption) (*shape.ComponentArtifact, error) { +func (l *Loader) LoadComponent(ctx context.Context, planned *shape.PlanResult, opts ...shape.LoadOption) (*shape.ComponentArtifact, error) { if err := ctx.Err(); err != nil { return nil, err } - pResult, resource, err := l.materialize(planned) + loadOptions := &shape.LoadOptions{UseTypeContextPackages: true} + for _, opt := range opts { + if opt != nil { + opt(loadOptions) + } + } + pResult, resource, err := l.materialize(ctx, planned, loadOptions) if err != nil { return nil, err } + if err := validateComponentRoutes(pResult.Components); err != nil { + return nil, err + } if len(pResult.Views) == 0 { - return nil, ErrEmptyViewPlan + if err := materializeComponentRouteView(planned.Source, pResult, resource); err != nil { + return nil, err + } + if len(resource.Views) == 0 && !allowsViewlessComponent(pResult.Components) { + return nil, ErrEmptyViewPlan + } } - component := buildComponent(planned.Source, pResult) + component := buildComponent(planned.Source, pResult, resource, loadOptions) return &shape.ComponentArtifact{ Resource: resource, Component: component, }, nil } -func (l *Loader) materialize(planned *shape.PlanResult) (*plan.Result, *view.Resource, error) { +func validateComponentRoutes(routes []*plan.ComponentRoute) error { + count := 0 + for _, route := range routes { + if route != nil { + count++ + } + } + if count <= 1 { + return nil + } + return fmt.Errorf("shape load: multiple component routes are not supported for a single component artifact") +} + +func allowsViewlessComponent(routes []*plan.ComponentRoute) bool { + for _, route := range routes { + if route != nil { + return true + } + } + return false +} + +func (l *Loader) materialize(ctx context.Context, planned *shape.PlanResult, loadOptions *shape.LoadOptions) (*plan.Result, *view.Resource, error) { if planned == nil || planned.Source == nil { return nil, nil, shape.ErrNilSource } @@ -77,393 +152,4390 @@ func (l *Loader) materialize(planned *shape.PlanResult) (*plan.Result, *view.Res if err != nil { return nil, nil, err } + if loadOptions != nil && loadOptions.UseTypeContextPackages { + inheritViewSchemaPackage(aView, pResult.TypeContext) + } resource.AddViews(aView) } + materializeConcreteViewSchemas(resource, planned.Source, pResult.TypeContext) + refineViewColumnConfigTypes(resource, planned.Source, pResult.TypeContext) + enrichConcreteViewColumns(resource) + assignViewSummarySchemas(resource, pResult, planned.Source) + enrichRelationLinkFields(pResult.Views) + attachViewRelations(resource, pResult.Views) + if err := enrichRelationHolderTypes(resource, pResult.Views); err != nil { + return nil, nil, err + } + refineSummarySchemas(resource) + applyVeltyAliasesToExecInputViews(resource, pResult) + materializeResourceTypes(resource, pResult.Views, planned.Source, pResult.TypeContext) + applyVeltyAliasesToExecInputViews(resource, pResult) + rootView := rootResourceView(resource, pResult.Views) + for _, item := range pResult.States { + if item == nil { + continue + } + param := cloneStateParameter(item) + if param == nil { + continue + } + normalizeDerivedInputSchema(param, resource) + if rootView != nil { + inheritRootBodySchema(param, rootView) + } + if rootView != nil { + inheritRootOutputSchema(param, rootView) + } + ensureMaterializedOutputSchema(param, rootView, planned.Source, pResult.TypeContext) + addResourceParameter(resource, param) + } if err := shapevalidate.ValidateRelations(resource, resource.Views...); err != nil { return nil, nil, err } - // Gap 7: apply global cache TTL directive to root view. + if len(pResult.Const) > 0 { + for k, v := range pResult.Const { + constParam := &state.Parameter{ + Name: k, + In: state.NewConstLocation(k), + Value: v, + Tag: `internal:"true"`, + Schema: &state.Schema{ + Name: "string", + DataType: "string", + Cardinality: state.One, + }, + } + addResourceParameter(resource, constParam) + } + } + bindTemplateParameters(resource) + // Apply cache directives only as resource-level provider definitions. + // View-level cache binding comes from explicit view metadata such as set_cache(...). if pResult.Directives != nil && pResult.Directives.Cache != nil { - if ttl := strings.TrimSpace(pResult.Directives.Cache.TTL); ttl != "" { - if dur, err := time.ParseDuration(ttl); err == nil && dur > 0 { - ttlMs := int(dur.Milliseconds()) - if rootPlan := pickRootView(pResult.Views); rootPlan != nil { - for _, rv := range resource.Views { - if rv != nil && rv.Name == rootPlan.Name { - if rv.Cache == nil { - rv.Cache = &view.Cache{} - } - rv.Cache.TimeToLiveMs = ttlMs - break - } - } - } + if name := strings.TrimSpace(pResult.Directives.Cache.Name); name != "" { + provider := strings.TrimSpace(pResult.Directives.Cache.Provider) + location := strings.TrimSpace(pResult.Directives.Cache.Location) + ttlMs := pResult.Directives.Cache.TimeToLiveMs + if provider != "" && location != "" && ttlMs > 0 { + resource.CacheProviders = append(resource.CacheProviders, &view.Cache{ + Name: name, + Provider: provider, + Location: location, + TimeToLiveMs: ttlMs, + }) } } } return pResult, resource, nil } -func buildComponent(source *shape.Source, pResult *plan.Result) *Component { +func buildComponent(source *shape.Source, pResult *plan.Result, resource *view.Resource, loadOptions *shape.LoadOptions) *Component { component := &Component{Method: "GET"} if source != nil { component.Name = source.Name component.URI = source.Name } + component.TypeContext = cloneTypeContext(pResult.TypeContext) + applyComponentRoutes(component, pResult.Components) applyViewMeta(component, pResult.Views) - applyStateBuckets(component, pResult.States) + applyMutableRootMode(component, resource) + applyStateBuckets(component, pResult.States, resource, source, pResult.TypeContext, loadOptions) + applyStateBuckets(component, synthesizeConstStates(pResult.Const), resource, source, pResult.TypeContext, loadOptions) + applyStateBuckets(component, synthesizeMissingRouteContractStates(component, pResult.Components), resource, source, pResult.TypeContext, loadOptions) + synthesizeMutableExecHelpers(component, resource) component.Input = append(component.Input, synthesizePredicateStates(component.Input, component.Predicates)...) - component.TypeContext = cloneTypeContext(pResult.TypeContext) component.Directives = cloneDirectives(pResult.Directives) + if primary := firstComponentRoute(pResult.Components); primary != nil && primary.Report != nil { + component.Report = &dqlshape.ReportDirective{ + Enabled: primary.Report.Enabled, + Input: strings.TrimSpace(primary.Report.Input), + Dimensions: strings.TrimSpace(primary.Report.Dimensions), + Measures: strings.TrimSpace(primary.Report.Measures), + Filters: strings.TrimSpace(primary.Report.Filters), + OrderBy: strings.TrimSpace(primary.Report.OrderBy), + Limit: strings.TrimSpace(primary.Report.Limit), + Offset: strings.TrimSpace(primary.Report.Offset), + } + } else if component.Directives != nil && component.Directives.Report != nil { + component.Report = &dqlshape.ReportDirective{ + Enabled: component.Directives.Report.Enabled, + Input: strings.TrimSpace(component.Directives.Report.Input), + Dimensions: strings.TrimSpace(component.Directives.Report.Dimensions), + Measures: strings.TrimSpace(component.Directives.Report.Measures), + Filters: strings.TrimSpace(component.Directives.Report.Filters), + OrderBy: strings.TrimSpace(component.Directives.Report.OrderBy), + Limit: strings.TrimSpace(component.Directives.Report.Limit), + Offset: strings.TrimSpace(component.Directives.Report.Offset), + } + } component.ColumnsDiscovery = pResult.ColumnsDiscovery + component.TypeSpecs = resolveTypeSpecs(pResult) return component } -// applyViewMeta populates the component with view names, declarations, relations, -// query selectors, predicate maps, and root view from the plan view list. -func applyViewMeta(component *Component, views []*plan.View) { - for _, aView := range views { - if aView == nil { - continue - } - component.Views = append(component.Views, aView.Name) - if aView.Declaration != nil { - indexViewDeclaration(component, aView.Name, aView.Declaration) - } - if len(aView.Relations) > 0 { - component.Relations = append(component.Relations, aView.Relations...) - component.ViewRelations = append(component.ViewRelations, toViewRelations(aView.Relations)...) - } +func addResourceParameter(resource *view.Resource, param *state.Parameter) { + if resource == nil || param == nil { + return } - if rootView := pickRootView(views); rootView != nil { - component.RootView = rootView.Name - if component.Name == "" { - component.Name = rootView.Name - } + resource.AddParameters(param) + if named := resource.NamedParameters(); named != nil { + _ = named.Register(param) } } -// indexViewDeclaration registers the declaration's query selector and predicates -// on the component index maps, creating them on demand. -func indexViewDeclaration(component *Component, viewName string, decl *plan.ViewDeclaration) { - if component.Declarations == nil { - component.Declarations = map[string]*plan.ViewDeclaration{} +func applyComponentRoutes(component *Component, routes []*plan.ComponentRoute) { + if component == nil || len(routes) == 0 { + return } - component.Declarations[viewName] = decl - if selector := strings.TrimSpace(decl.QuerySelector); selector != "" { - if component.QuerySelectors == nil { - component.QuerySelectors = map[string][]string{} - } - component.QuerySelectors[selector] = append(component.QuerySelectors[selector], viewName) + component.ComponentRoutes = cloneComponentRoutes(routes) + primary := firstComponentRoute(routes) + if primary == nil { + return } - if len(decl.Predicates) > 0 { - if component.Predicates == nil { - component.Predicates = map[string][]*plan.ViewPredicate{} + if uri := strings.TrimSpace(primary.RoutePath); uri != "" { + component.URI = uri + if strings.TrimSpace(component.Name) == "" { + component.Name = uri } - component.Predicates[viewName] = append(component.Predicates[viewName], decl.Predicates...) } -} - -// applyStateBuckets sorts plan states into the typed buckets on the component -// (Input, Output, Meta, Async, Other) based on the state's location kind. -func applyStateBuckets(component *Component, states []*plan.State) { - for _, item := range states { - if item == nil { - continue - } - kind := state.Kind(strings.ToLower(item.KindString())) - inName := item.InName() - if kind == "" && inName == "" { - component.Other = append(component.Other, item) - continue - } - switch kind { - case state.KindQuery, state.KindPath, state.KindHeader, state.KindRequestBody, - state.KindForm, state.KindCookie, state.KindRequest, "": - component.Input = append(component.Input, item) - case state.KindOutput: - component.Output = append(component.Output, item) - case state.KindMeta: - component.Meta = append(component.Meta, item) - case state.KindAsync: - component.Async = append(component.Async, item) - default: - component.Other = append(component.Other, item) + if method := strings.TrimSpace(primary.Method); method != "" { + component.Method = method + } + if strings.TrimSpace(component.Name) == "" { + component.Name = strings.TrimSpace(primary.Name) + } + if strings.TrimSpace(component.RootView) == "" && strings.TrimSpace(primary.ViewName) != "" { + component.RootView = routeViewAlias(primary) + if component.RootView != "" { + component.Views = append(component.Views, component.RootView) } } } -// synthesizePredicateStates creates query parameters for view-level predicates whose -// source parameter is not already present in the input state list. -func synthesizePredicateStates(input []*plan.State, predicates map[string][]*plan.ViewPredicate) []*plan.State { - if len(predicates) == 0 { - return nil +func applyMutableRootMode(component *Component, resource *view.Resource) { + if component == nil || resource == nil { + return } - declared := make(map[string]bool, len(input)) - for _, s := range input { - if s != nil { - declared[strings.ToLower(strings.TrimPrefix(strings.TrimSpace(s.Name), "$"))] = true - } + if strings.EqualFold(strings.TrimSpace(component.Method), "GET") { + return } - var result []*plan.State - for _, viewPredicates := range predicates { - for _, vp := range viewPredicates { - if vp == nil { - continue - } - src := strings.TrimPrefix(strings.TrimSpace(vp.Source), "$") - if src == "" || declared[strings.ToLower(src)] { - continue - } - result = append(result, &plan.State{ - Parameter: state.Parameter{ - Name: src, - In: state.NewQueryLocation(src), - Schema: &state.Schema{DataType: "string"}, - Predicates: []*extension.PredicateConfig{ - { - Name: vp.Name, - Ensure: vp.Ensure, - Args: append([]string{}, vp.Arguments...), - }, - }, - }, - }) - declared[strings.ToLower(src)] = true - } + rootView := lookupNamedResourceView(resource, component.RootView) + if rootView == nil { + return + } + if rootView.Mode != view.ModeHandler { + rootView.Mode = view.ModeExec } - return result } -func cloneTypeContext(input *typectx.Context) *typectx.Context { - if input == nil { +func cloneComponentRoutes(routes []*plan.ComponentRoute) []*plan.ComponentRoute { + if len(routes) == 0 { return nil } - ret := &typectx.Context{ - DefaultPackage: strings.TrimSpace(input.DefaultPackage), - PackageDir: strings.TrimSpace(input.PackageDir), - PackageName: strings.TrimSpace(input.PackageName), - PackagePath: strings.TrimSpace(input.PackagePath), - } - for _, item := range input.Imports { - pkg := strings.TrimSpace(item.Package) - if pkg == "" { + result := make([]*plan.ComponentRoute, 0, len(routes)) + for _, item := range routes { + if item == nil { continue } - ret.Imports = append(ret.Imports, typectx.Import{ - Alias: strings.TrimSpace(item.Alias), - Package: pkg, - }) + cloned := *item + result = append(result, &cloned) } - if ret.DefaultPackage == "" && - len(ret.Imports) == 0 && - ret.PackageDir == "" && - ret.PackageName == "" && - ret.PackagePath == "" { + if len(result) == 0 { return nil } - return ret + return result } -func cloneDirectives(input *dqlshape.Directives) *dqlshape.Directives { - if input == nil { - return nil +func synthesizeMutableExecHelpers(component *Component, resource *view.Resource) { + if component == nil || resource == nil { + return } - ret := &dqlshape.Directives{ - Meta: strings.TrimSpace(input.Meta), - DefaultConnector: strings.TrimSpace(input.DefaultConnector), + if strings.EqualFold(strings.TrimSpace(component.Method), "GET") { + return } - if input.Cache != nil { - ret.Cache = &dqlshape.CacheDirective{ - Enabled: input.Cache.Enabled, - TTL: strings.TrimSpace(input.Cache.TTL), - } + if templateType := strings.ToLower(strings.TrimSpace(componentTemplateType(component))); templateType != "" && templateType != "patch" { + return } - if input.MCP != nil { - ret.MCP = &dqlshape.MCPDirective{ - Name: strings.TrimSpace(input.MCP.Name), - Description: strings.TrimSpace(input.MCP.Description), - DescriptionPath: strings.TrimSpace(input.MCP.DescriptionPath), - } + rootView := lookupNamedResourceView(resource, component.RootView) + if rootView == nil || rootView.Schema == nil || rootView.Mode != view.ModeExec { + return } - if ret.Meta == "" && ret.DefaultConnector == "" && ret.Cache == nil && ret.MCP == nil { - return nil + _ = view.WithTemplateParameterStateType(true)(rootView) + body := firstMutableBodyState(component.Input) + if body == nil || body.In == nil || body.Schema == nil { + return } - return ret -} + bodyName := strings.TrimSpace(body.Name) + if bodyName == "" { + return + } + helperViewName := "Cur" + bodyName + if hasInputState(component.Input, helperViewName) { + return + } + keyFieldName, keyColumnName, keyType := mutableKeyDescriptor(rootView, body.Schema) + if keyFieldName == "" || keyColumnName == "" || keyType == nil { + return + } + componentDir := text.CaseFormatUpperCamel.Format(strings.TrimSpace(componentRootName(component, rootView, bodyName)), text.CaseFormatLowerUnderscore) + if componentDir == "" { + componentDir = text.CaseFormatUpperCamel.Format(bodyName, text.CaseFormatLowerUnderscore) + } + helperIDsName := helperViewName + keyFieldName + helperViewURI := path.Join(componentDir, text.CaseFormatUpperCamel.Format(helperViewName, text.CaseFormatLowerUnderscore)+".sql") -func pickRootView(views []*plan.View) *plan.View { - var selected *plan.View - minDepth := -1 - for _, candidate := range views { - if candidate == nil || candidate.Path == "" { - continue - } - depth := strings.Count(candidate.Path, ".") - if minDepth == -1 || depth < minDepth { - minDepth = depth - selected = candidate - } + valuesType := reflect.StructOf([]reflect.StructField{{ + Name: "Values", + Type: reflect.SliceOf(keyType), + Tag: reflect.StructTag(`json:",omitempty"`), + }}) + helperIDsSchema := state.NewSchema(reflect.PtrTo(valuesType)) + if helperIDsSchema != nil && strings.TrimSpace(helperIDsSchema.DataType) == "" { + helperIDsSchema.DataType = loaderSchemaTypeExpr(reflect.PtrTo(valuesType)) } - if selected != nil { - return selected + helperSourceSchema := body.Schema + if helperSourceSchema == nil && rootView.Schema != nil { + helperSourceSchema = rootView.Schema.Clone() } - for _, candidate := range views { - if candidate != nil { - return candidate - } + if helperSourceSchema != nil { + helperSourceSchema = helperSourceSchema.Clone() } - return nil -} + helperIDsParam := &state.Parameter{ + Name: helperIDsName, + In: state.NewParameterLocation(bodyName), + Schema: helperSourceSchema, + Output: &state.Codec{Name: "structql", Body: fmt.Sprintf(" SELECT ARRAY_AGG(%s) AS Values FROM `/` LIMIT 1", keyFieldName), Schema: helperIDsSchema.Clone()}, + PreserveSchema: true, + } + resource.Parameters.Append(helperIDsParam) -func materializeView(item *plan.View) (*view.View, error) { - if item == nil { - return nil, fmt.Errorf("shape load: nil view plan item") + helperSchema := rootView.Schema.Clone() + helperSchema.Cardinality = state.Many + helperViewParamSchema := helperSchema.Clone() + if helperViewParamSchema.Cardinality == "" { + helperViewParamSchema.Cardinality = state.Many + } + helperViewParam := &state.Parameter{ + Name: helperViewName, + In: state.NewViewLocation(helperViewName), + Tag: fmt.Sprintf(`view:"%s" sql:"uri=%s"`, helperViewName, helperViewURI), + Schema: helperViewParamSchema, } + resource.Parameters.Append(helperViewParam) + bindViewTemplateParameters(rootView, []*state.Parameter{ + helperIDsParam, + helperViewParam, + }) - schemaType := bestSchemaType(item) - if schemaType == nil { - return nil, fmt.Errorf("shape load: missing schema type for view %q", item.Name) + helperView := view.NewView(helperViewName, "", view.WithSchema(helperSchema.Clone()), view.WithMode(view.ModeQuery)) + helperView.Table = rootView.Table + helperView.Connector = rootView.Connector + helperView.Columns = cloneViewColumns(rootView.Columns) + helperView.ColumnsConfig = cloneViewColumnsConfig(rootView.ColumnsConfig) + helperView.Selector = &view.Config{ + Namespace: strings.ToLower(truncateString(helperViewName, 2)), + Limit: 1000, + Constraints: &view.Constraints{ + Criteria: true, + Limit: true, + Offset: true, + Projection: true, + }, } + helperView.Template = view.NewTemplate( + fmt.Sprintf("SELECT * FROM %s\nWHERE $criteria.In(%q, $Unsafe.%s.Values)", rootView.Table, keyColumnName, helperIDsName), + view.WithTemplateParameters(helperIDsParam), + view.WithTemplateUnsafeStateFromParameters(true), + view.WithTemplateDeclaredParametersOnly(true), + view.WithTemplateResourceParameterLookup(true), + ) + helperView.Template.SourceURL = helperViewURI + resource.AddViews(helperView) + component.Views = append(component.Views, helperViewName) + synthesizeMutableRootTemplate(component, rootView, body, keyFieldName, helperViewName) +} - schema := newSchema(schemaType, item.Cardinality) - mode := view.ModeQuery - switch strings.TrimSpace(item.Mode) { - case string(view.ModeExec): - mode = view.ModeExec - case string(view.ModeHandler): - mode = view.ModeHandler - case string(view.ModeQuery): - mode = view.ModeQuery +func componentTemplateType(component *Component) string { + if component == nil || component.Directives == nil { + return "" } - opts := []view.Option{view.WithSchema(schema), view.WithMode(mode)} + return strings.TrimSpace(component.Directives.TemplateType) +} - if item.Connector != "" { - opts = append(opts, view.WithConnectorRef(item.Connector)) +func synthesizeMutableRootTemplate(component *Component, rootView *view.View, body *plan.State, keyFieldName string, helperViewName string) { + if component == nil || rootView == nil || body == nil || body.Schema == nil { + return } - if item.SQL != "" || item.SQLURI != "" { - tmpl := view.NewTemplate(item.SQL) - tmpl.SourceURL = item.SQLURI - if strings.TrimSpace(item.Summary) != "" { - tmpl.Summary = &view.TemplateSummary{ - Name: "Summary", - Source: item.Summary, - Kind: view.MetaKindRecord, - } - } - opts = append(opts, view.WithTemplate(tmpl)) + method := strings.ToUpper(strings.TrimSpace(component.Method)) + switch method { + case "PATCH", "POST", "PUT": + default: + return } - if item.CacheRef != "" { - opts = append(opts, view.WithCache(&view.Cache{Reference: shared.Reference{Ref: item.CacheRef}})) + bodyName := strings.TrimSpace(body.Name) + if bodyName == "" || strings.TrimSpace(rootView.Table) == "" || strings.TrimSpace(keyFieldName) == "" { + return } - if item.Partitioner != "" { - opts = append(opts, view.WithPartitioned(&view.Partitioned{ - DataType: item.Partitioner, - Concurrency: item.PartitionedConcurrency, - })) + if rootView.Template == nil { + rootView.Template = view.NewTemplate("", view.WithTemplateParameters()) + } + if rootView.TableBatches == nil { + rootView.TableBatches = map[string]bool{} } + rootView.TableBatches[rootView.Table] = true + rootView.Template.Source = buildMutableRootTemplate(method, rootView.Table, bodyName, keyFieldName, helperViewName, body.Schema.Cardinality == state.Many) +} - aView, err := view.New(item.Name, item.Table, opts...) - if err != nil { - return nil, err +func buildMutableRootTemplate(method string, tableName string, bodyName string, keyFieldName string, helperViewName string, many bool) string { + var builder strings.Builder + if strings.ToUpper(strings.TrimSpace(method)) != "PUT" { + builder.WriteString(fmt.Sprintf("$sequencer.Allocate(%q, $Unsafe.%s, %q)\n\n", tableName, bodyName, keyFieldName)) } - aView.Ref = item.Ref - aView.Module = item.Module - aView.AllowNulls = item.AllowNulls - // Gap 6: forward view-level tag from declaration. - if item.Declaration != nil && strings.TrimSpace(item.Declaration.Tag) != "" { - aView.Tag = strings.TrimSpace(item.Declaration.Tag) + mapName := helperViewName + "By" + keyFieldName + builder.WriteString(fmt.Sprintf("#set($%s = $Unsafe.%s.IndexBy(%q))\n\n", mapName, helperViewName, keyFieldName)) + if many { + recordVar := "Rec" + bodyName + builder.WriteString(fmt.Sprintf("#foreach($%s in $Unsafe.%s)\n", recordVar, bodyName)) + builder.WriteString(fmt.Sprintf(" #if($%s.HasKey($%s.%s) == true)\n", mapName, recordVar, keyFieldName)) + builder.WriteString(fmt.Sprintf("$sql.Update($%s, %q);\n", recordVar, tableName)) + builder.WriteString(" #else\n") + builder.WriteString(fmt.Sprintf("$sql.Insert($%s, %q);\n", recordVar, tableName)) + builder.WriteString(" #end\n") + builder.WriteString("#end") + return builder.String() } - if strings.TrimSpace(item.SelectorNamespace) != "" || item.SelectorNoLimit != nil { - if aView.Selector == nil { - aView.Selector = &view.Config{} + builder.WriteString(fmt.Sprintf("#if($Unsafe.%s)\n", bodyName)) + builder.WriteString(fmt.Sprintf(" #if($%s.HasKey($Unsafe.%s.%s) == true)\n", mapName, bodyName, keyFieldName)) + builder.WriteString(fmt.Sprintf("$sql.Update($Unsafe.%s, %q);\n", bodyName, tableName)) + builder.WriteString(" #else\n") + builder.WriteString(fmt.Sprintf("$sql.Insert($Unsafe.%s, %q);\n", bodyName, tableName)) + builder.WriteString(" #end\n") + builder.WriteString("#end") + return builder.String() +} + +func loaderSchemaTypeExpr(rType reflect.Type) string { + if rType == nil { + return "" + } + switch rType.Kind() { + case reflect.Ptr: + return "*" + loaderSchemaTypeExpr(rType.Elem()) + case reflect.Slice: + return "[]" + loaderSchemaTypeExpr(rType.Elem()) + case reflect.Array: + return fmt.Sprintf("[%d]%s", rType.Len(), loaderSchemaTypeExpr(rType.Elem())) + case reflect.Map: + return "map[" + loaderSchemaTypeExpr(rType.Key()) + "]" + loaderSchemaTypeExpr(rType.Elem()) + default: + return rType.String() + } +} + +func firstMutableBodyState(states []*plan.State) *plan.State { + for _, item := range states { + if item == nil || item.In == nil || item.In.Kind != state.KindRequestBody { + continue } - if strings.TrimSpace(item.SelectorNamespace) != "" { - aView.Selector.Namespace = strings.TrimSpace(item.SelectorNamespace) + if !item.IsAnonymous() { + continue } - if item.SelectorNoLimit != nil { - aView.Selector.NoLimit = *item.SelectorNoLimit + return item + } + return nil +} + +func hasInputState(states []*plan.State, name string) bool { + for _, item := range states { + if item == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(item.Name), strings.TrimSpace(name)) { + return true } } - if aView.Schema != nil && strings.TrimSpace(item.SchemaType) != "" { - if aView.Schema.DataType == "" { - aView.Schema.DataType = strings.TrimSpace(item.SchemaType) + return false +} + +func componentRootName(component *Component, rootView *view.View, fallback string) string { + if rootView != nil && strings.TrimSpace(rootView.Name) != "" { + return rootView.Name + } + if component != nil && strings.TrimSpace(component.RootView) != "" { + return component.RootView + } + return fallback +} + +func mutableKeyDescriptor(rootView *view.View, schema *state.Schema) (string, string, reflect.Type) { + if fieldName, columnName, rType := mutableKeyFromType(schema); fieldName != "" && columnName != "" && rType != nil { + return fieldName, columnName, rType + } + if rootView == nil { + return "", "", nil + } + for _, column := range rootView.Columns { + if column == nil { + continue } - if aView.Schema.Name == "" { - aView.Schema.Name = strings.Trim(strings.TrimSpace(item.SchemaType), "*") + dbColumn := strings.TrimSpace(column.DatabaseColumn) + if dbColumn == "" { + dbColumn = strings.TrimSpace(column.Name) + } + if !strings.EqualFold(dbColumn, "ID") { + continue + } + fieldName := strings.TrimSpace(column.FieldName()) + if fieldName == "" { + fieldName = "Id" + } + switch strings.ToLower(strings.TrimSpace(column.DataType)) { + case "int", "integer", "bigint", "smallint": + return fieldName, dbColumn, reflect.TypeOf(0) } } - // Populate columns from statically-inferred struct type so that xgen can - // generate accurate Go struct definitions during bootstrap. Only applied when - // the view has no columns yet (avoids overwriting explicit column config). - if len(aView.Columns) == 0 { - if cols := inferColumnsFromType(item.ElementType); len(cols) > 0 { - aView.Columns = cols + return "", "", nil +} + +func mutableKeyFromType(schema *state.Schema) (string, string, reflect.Type) { + if schema == nil || schema.Type() == nil { + return "", "", nil + } + rType := schema.Type() + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return "", "", nil + } + if field, ok := rType.FieldByName("Id"); ok { + return "Id", "ID", derefType(field.Type) + } + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + sqlxTag := field.Tag.Get("sqlx") + if sqlxTag == "ID" || strings.Contains(sqlxTag, "name=ID") { + return field.Name, "ID", derefType(field.Type) } } - return aView, nil + return "", "", nil } -func bestSchemaType(item *plan.View) reflect.Type { - if item.FieldType != nil { - return item.FieldType +func derefType(rType reflect.Type) reflect.Type { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() } - if item.ElementType != nil { - return item.ElementType + return rType +} + +func truncateString(value string, max int) string { + value = strings.TrimSpace(value) + if max <= 0 || len(value) <= max { + return value } - return nil + return value[:max] } -func toViewRelations(input []*plan.Relation) []*view.Relation { - if len(input) == 0 { +func cloneViewColumns(columns []*view.Column) []*view.Column { + if len(columns) == 0 { return nil } - result := make([]*view.Relation, 0, len(input)) - for _, item := range input { + result := make([]*view.Column, 0, len(columns)) + for _, item := range columns { if item == nil { continue } - relation := &view.Relation{ - Name: item.Name, - Holder: item.Holder, - On: toViewLinks(item.On, true), - Of: view.NewReferenceView( - toViewLinks(item.On, false), - view.NewView(item.Ref, item.Table), - ), - } - result = append(result, relation) + cloned := *item + result = append(result, &cloned) } return result } -func toViewLinks(input []*plan.RelationLink, parent bool) view.Links { - if len(input) == 0 { +func cloneViewColumnsConfig(columns map[string]*view.ColumnConfig) map[string]*view.ColumnConfig { + if len(columns) == 0 { return nil } - result := make(view.Links, 0, len(input)) - for _, item := range input { + result := make(map[string]*view.ColumnConfig, len(columns)) + for key, item := range columns { if item == nil { continue } - link := &view.Link{} - if parent { - link.Field = item.ParentField - link.Namespace = item.ParentNamespace - link.Column = item.ParentColumn - } else { - link.Field = item.RefField - link.Namespace = item.RefNamespace - link.Column = item.RefColumn - } - result = append(result, link) + cloned := *item + result[key] = &cloned } return result } -func newSchema(rType reflect.Type, cardinality string) *state.Schema { - if cardinality == "many" && rType.Kind() != reflect.Slice { - return state.NewSchema(rType, state.WithMany()) +func lookupNamedResourceView(resource *view.Resource, name string) *view.View { + if resource == nil { + return nil + } + if strings.TrimSpace(name) != "" { + for _, item := range resource.Views { + if item != nil && strings.EqualFold(strings.TrimSpace(item.Name), strings.TrimSpace(name)) { + return item + } + } + } + for _, item := range resource.Views { + if item != nil { + return item + } + } + return nil +} + +func firstComponentRoute(routes []*plan.ComponentRoute) *plan.ComponentRoute { + for _, item := range routes { + if item != nil { + return item + } + } + return nil +} + +func materializeComponentRouteView(source *shape.Source, pResult *plan.Result, resource *view.Resource) error { + route := firstComponentRoute(pResult.Components) + if route == nil || resource == nil { + return nil + } + viewType := resolveRouteViewType(source, route) + if viewType == nil { + return nil + } + viewName := routeViewAlias(route) + if viewName == "" { + viewName = "View" + } + opts := []view.Option{ + view.WithSchema(state.NewSchema(viewType)), + view.WithMode(componentRouteMode(route)), + } + if connectorRef := strings.TrimSpace(route.Connector); connectorRef != "" { + opts = append(opts, view.WithConnectorRef(connectorRef)) + } + rootView := view.NewView(viewName, "", opts...) + if sourceURL := absoluteRouteSourceURL(source, route); sourceURL != "" { + tmpl := view.NewTemplate("") + tmpl.SourceURL = sourceURL + rootView.Template = tmpl + } + resource.AddViews(rootView) + return nil +} + +func componentRouteMode(route *plan.ComponentRoute) view.Mode { + if route == nil { + return view.ModeQuery + } + if strings.TrimSpace(route.Handler) != "" { + return view.ModeHandler + } + switch strings.ToUpper(strings.TrimSpace(route.Method)) { + case "", "GET": + return view.ModeQuery + default: + return view.ModeExec + } +} + +func resolveRouteViewType(source *shape.Source, route *plan.ComponentRoute) reflect.Type { + if source == nil || route == nil { + return nil + } + typeName := strings.TrimSpace(route.ViewName) + if typeName == "" { + return nil + } + registry := source.EnsureTypeRegistry() + if registry == nil { + return nil + } + if lookup := registry.Lookup(typeName); lookup != nil && lookup.Type != nil { + return lookup.Type + } + resolver := typectx.NewResolver(registry, nil) + if resolved, err := resolver.Resolve(typeName); err == nil && resolved != "" { + if lookup := registry.Lookup(resolved); lookup != nil && lookup.Type != nil { + return lookup.Type + } + } + return nil +} + +func routeViewAlias(route *plan.ComponentRoute) string { + if route == nil { + return "" + } + if name := strings.TrimSpace(route.Name); name != "" { + return name + } + viewName := strings.TrimSpace(route.ViewName) + if index := strings.LastIndex(viewName, "."); index >= 0 { + viewName = viewName[index+1:] + } + viewName = strings.TrimSuffix(viewName, "View") + return strings.TrimSpace(viewName) +} + +func absoluteRouteSourceURL(source *shape.Source, route *plan.ComponentRoute) string { + if route == nil { + return "" + } + sourceURL := strings.TrimSpace(route.SourceURL) + return absolutizeRouteAssetURL(source, sourceURL) +} + +func absoluteRouteSummaryURL(source *shape.Source, route *plan.ComponentRoute) string { + if route == nil { + return "" + } + return absolutizeRouteAssetURL(source, strings.TrimSpace(route.SummaryURL)) +} + +func absolutizeRouteAssetURL(source *shape.Source, sourceURL string) string { + if sourceURL == "" || strings.Contains(sourceURL, "://") { + return sourceURL + } + if filepath.IsAbs(sourceURL) { + return sourceURL + } + baseDir := "" + if source != nil { + baseDir = source.BaseDir() + } + if baseDir == "" { + return sourceURL + } + return filepath.Join(baseDir, filepath.FromSlash(sourceURL)) +} + +func resolveTypeSpecs(pResult *plan.Result) map[string]*TypeSpec { + if pResult == nil { + return nil + } + specs := map[string]*TypeSpec{} + directives := pResult.Directives + if directives != nil { + if typeName := strings.TrimSpace(directives.InputType); typeName != "" { + specs["input"] = &TypeSpec{Key: "input", Role: TypeRoleInput, TypeName: typeName, Source: "directive"} + } + if typeName := strings.TrimSpace(directives.OutputType); typeName != "" { + specs["output"] = &TypeSpec{Key: "output", Role: TypeRoleOutput, TypeName: typeName, Source: "directive"} + } + if dest := strings.TrimSpace(directives.InputDest); dest != "" { + spec := ensureTypeSpec(specs, "input", TypeRoleInput) + spec.Dest = dest + spec.Source = "directive" + } + if dest := strings.TrimSpace(directives.OutputDest); dest != "" { + spec := ensureTypeSpec(specs, "output", TypeRoleOutput) + spec.Dest = dest + spec.Source = "directive" + } + } + globalDest := "" + if directives != nil { + globalDest = strings.TrimSpace(directives.Dest) + } + for _, aView := range pResult.Views { + if aView == nil || strings.TrimSpace(aView.Name) == "" { + continue + } + key := "view:" + aView.Name + spec := ensureTypeSpec(specs, key, TypeRoleView) + spec.Alias = aView.Name + if globalDest != "" && spec.Dest == "" { + spec.Dest = globalDest + spec.Inherited = true + spec.Source = "directive" + } + if aView.Declaration != nil { + if typeName := strings.TrimSpace(aView.Declaration.TypeName); typeName != "" { + spec.TypeName = typeName + spec.Source = "decl" + } + if dest := strings.TrimSpace(aView.Declaration.Dest); dest != "" { + spec.Dest = dest + spec.Inherited = false + spec.Source = "decl" + } + if tagType, tagDest := parseTypeSpecTag(aView.Declaration.Tag); tagType != "" || tagDest != "" { + if spec.TypeName == "" && tagType != "" { + spec.TypeName = tagType + spec.Source = "annotation" + } + if spec.Dest == "" && tagDest != "" { + spec.Dest = tagDest + spec.Inherited = false + spec.Source = "annotation" + } + } + } + } + if root := pickRootView(pResult.Views); root != nil { + if rootSpec := specs["view:"+root.Name]; rootSpec != nil && strings.TrimSpace(rootSpec.Dest) != "" { + rootDest := strings.TrimSpace(rootSpec.Dest) + for _, aView := range pResult.Views { + if aView == nil || strings.TrimSpace(aView.Name) == "" || aView.Name == root.Name { + continue + } + spec := ensureTypeSpec(specs, "view:"+aView.Name, TypeRoleView) + spec.Alias = aView.Name + if strings.TrimSpace(spec.Dest) == "" || spec.Source == "directive" || spec.Source == "inherit" { + spec.Dest = rootDest + spec.Inherited = true + spec.Source = "inherit" + } + } + } + } + if globalDest != "" { + inputSpec := ensureTypeSpec(specs, "input", TypeRoleInput) + if strings.TrimSpace(inputSpec.Dest) == "" { + inputSpec.Dest = globalDest + inputSpec.Inherited = true + if inputSpec.Source == "" { + inputSpec.Source = "directive" + } + } + outputSpec := ensureTypeSpec(specs, "output", TypeRoleOutput) + if strings.TrimSpace(outputSpec.Dest) == "" { + outputSpec.Dest = globalDest + outputSpec.Inherited = true + if outputSpec.Source == "" { + outputSpec.Source = "directive" + } + } + } + if len(specs) == 0 { + return nil + } + return specs +} + +func ensureTypeSpec(specs map[string]*TypeSpec, key string, role TypeRole) *TypeSpec { + if spec, ok := specs[key]; ok && spec != nil { + return spec + } + spec := &TypeSpec{Key: key, Role: role} + specs[key] = spec + return spec +} + +func parseTypeSpecTag(raw string) (string, string) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "" + } + var typeName, dest string + for _, part := range strings.Split(raw, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + key, value, ok := strings.Cut(part, "=") + if !ok { + continue + } + key = strings.ToLower(strings.TrimSpace(key)) + value = strings.TrimSpace(strings.Trim(value, `"'`)) + switch key { + case "type": + typeName = value + case "dest": + dest = value + } + } + return strings.TrimSpace(typeName), strings.TrimSpace(dest) +} + +// applyViewMeta populates the component with view names, declarations, relations, +// query selectors, predicate maps, and root view from the plan view list. +func applyViewMeta(component *Component, views []*plan.View) { + for _, aView := range views { + if aView == nil { + continue + } + component.Views = append(component.Views, aView.Name) + if aView.Declaration != nil { + indexViewDeclaration(component, declaredViewIndexName(aView), aView.Declaration) + } + if len(aView.Relations) > 0 { + component.Relations = append(component.Relations, aView.Relations...) + component.ViewRelations = append(component.ViewRelations, toViewRelations(aView.Relations)...) + } + } + if rootView := pickRootView(views); rootView != nil { + component.RootView = rootView.Name + if component.Name == "" { + component.Name = rootView.Name + } + } +} + +func declaredViewIndexName(aView *plan.View) string { + if aView == nil { + return "" + } + if queryNode, err := sqlparser.ParseQuery(strings.TrimSpace(aView.SQL)); err == nil && queryNode != nil { + if inferredName, _, err := pipeline.InferRoot(queryNode, aView.Name); err == nil && strings.TrimSpace(inferredName) != "" { + return inferredName + } + } + return aView.Name +} + +// indexViewDeclaration registers the declaration's query selector and predicates +// on the component index maps, creating them on demand. +func indexViewDeclaration(component *Component, viewName string, decl *plan.ViewDeclaration) { + if component.Declarations == nil { + component.Declarations = map[string]*plan.ViewDeclaration{} + } + component.Declarations[viewName] = decl + if selector := strings.TrimSpace(decl.QuerySelector); selector != "" { + if component.QuerySelectors == nil { + component.QuerySelectors = map[string][]string{} + } + component.QuerySelectors[selector] = append(component.QuerySelectors[selector], viewName) + } + if len(decl.Predicates) > 0 { + if component.Predicates == nil { + component.Predicates = map[string][]*plan.ViewPredicate{} + } + component.Predicates[viewName] = append(component.Predicates[viewName], decl.Predicates...) + } +} + +// applyStateBuckets sorts plan states into the typed buckets on the component +// (Input, Output, Meta, Async, Other) based on the state's location kind. +func applyStateBuckets(component *Component, states []*plan.State, resource *view.Resource, source *shape.Source, ctx *typectx.Context, loadOptions *shape.LoadOptions) { + for _, item := range states { + if item == nil { + continue + } + cloned := clonePlanState(item) + if cloned == nil { + continue + } + if loadOptions != nil && loadOptions.UseTypeContextPackages { + inheritTypeContextSchemaPackage(&cloned.Parameter, component) + } + if selector := strings.TrimSpace(cloned.QuerySelector); selector != "" { + if component.QuerySelectors == nil { + component.QuerySelectors = map[string][]string{} + } + component.QuerySelectors[selector] = append(component.QuerySelectors[selector], cloned.Name) + } + normalizeDerivedInputSchema(&cloned.Parameter, resource) + inheritRootBodySchema(&cloned.Parameter, rootResourceView(resource, nil)) + inheritRootOutputSchema(&cloned.Parameter, rootResourceView(resource, nil)) + ensureMaterializedOutputSchema(&cloned.Parameter, rootResourceView(resource, nil), source, ctx) + kind := state.Kind(strings.ToLower(item.KindString())) + inName := item.InName() + if kind == "" && inName == "" { + component.Other = append(component.Other, cloned) + continue + } + if cloned.EmitOutput && kind != state.KindOutput { + outputClone := clonePlanState(cloned) + if outputClone != nil { + component.Output = append(component.Output, outputClone) + } + } + switch kind { + case state.KindQuery, state.KindPath, state.KindHeader, state.KindRequestBody, + state.KindView, state.KindComponent, state.KindConst, + state.KindForm, state.KindCookie, state.KindRequest, "": + if kind == state.KindComponent { + normalizeDynamicComponentSchema(&cloned.Parameter) + } + component.Input = append(component.Input, cloned) + case state.KindOutput: + component.Output = append(component.Output, cloned) + case state.KindMeta: + component.Meta = append(component.Meta, cloned) + case state.KindAsync: + component.Async = append(component.Async, cloned) + default: + component.Other = append(component.Other, cloned) + } + } +} + +func normalizeDynamicComponentSchema(param *state.Parameter) { + if param == nil || param.Schema == nil { + return + } + param.Schema.SetType(reflect.TypeOf((*interface{})(nil)).Elem()) + param.Schema.Package = "" + param.Schema.PackagePath = "" +} + +func inheritTypeContextSchemaPackage(param *state.Parameter, component *Component) { + if param == nil || param.Schema == nil || component == nil || component.TypeContext == nil { + return + } + if param.Schema.Type() != nil { + return + } + if strings.TrimSpace(param.Schema.Package) != "" || strings.TrimSpace(param.Schema.PackagePath) != "" { + return + } + typeName := strings.TrimSpace(shared.FirstNotEmpty(param.Schema.Name, param.Schema.DataType)) + if !shouldInheritTypeContextPackage(typeName) { + return + } + if _, err := types.LookupType(nil, typeName); err == nil { + return + } + if baseType := schemaBaseTypeName(typeName); baseType != typeName { + if _, err := types.LookupType(nil, baseType); err == nil { + return + } + } + pkg, pkgPath := schemaTypeContextPackage(component.TypeContext) + if pkgPath == "" { + return + } + if pkg != "" { + param.Schema.Package = pkg + } + param.Schema.PackagePath = pkgPath +} + +func inheritViewSchemaPackage(aView *view.View, ctx *typectx.Context) { + if aView == nil || aView.Schema == nil || ctx == nil { + return + } + if aView.Schema.Type() != nil { + return + } + if strings.TrimSpace(aView.Schema.Package) != "" || strings.TrimSpace(aView.Schema.PackagePath) != "" { + return + } + typeName := strings.TrimSpace(shared.FirstNotEmpty(aView.Schema.Name, aView.Schema.DataType)) + if !shouldInheritTypeContextPackage(typeName) { + return + } + if _, err := types.LookupType(nil, typeName); err == nil { + return + } + if baseType := schemaBaseTypeName(typeName); baseType != typeName { + if _, err := types.LookupType(nil, baseType); err == nil { + return + } + } + pkg, pkgPath := schemaTypeContextPackage(ctx) + if pkgPath == "" { + return + } + if pkg != "" { + aView.Schema.Package = pkg + } + aView.Schema.PackagePath = pkgPath +} + +func schemaTypeContextPackage(ctx *typectx.Context) (string, string) { + if ctx == nil { + return "", "" + } + pkg := strings.TrimSpace(ctx.PackageName) + pkgPath := strings.TrimSpace(ctx.PackagePath) + if pkgPath == "" { + pkgPath = strings.TrimSpace(ctx.DefaultPackage) + } + if pkg == "" && pkgPath != "" { + pkg = path.Base(pkgPath) + } + return pkg, pkgPath +} + +func shouldInheritTypeContextPackage(typeName string) bool { + baseType := schemaBaseTypeName(typeName) + if baseType == "" { + return false + } + if strings.Contains(baseType, ".") { + return false + } + if builtinSchemaTypes[baseType] { + return false + } + return true +} + +func schemaBaseTypeName(typeName string) string { + typeName = strings.TrimSpace(typeName) + for { + switch { + case strings.HasPrefix(typeName, "[]"): + typeName = strings.TrimSpace(typeName[2:]) + case strings.HasPrefix(typeName, "*"): + typeName = strings.TrimSpace(typeName[1:]) + default: + goto done + } + } +done: + if typeName == "" { + return "" + } + if strings.ContainsAny(typeName, " {}[](),") { + return "" + } + return typeName +} + +var builtinSchemaTypes = map[string]bool{ + "any": true, + "bool": true, + "byte": true, + "complex128": true, + "complex64": true, + "error": true, + "float32": true, + "float64": true, + "int": true, + "int16": true, + "int32": true, + "int64": true, + "int8": true, + "interface{}": true, + "rune": true, + "string": true, + "uint": true, + "uint16": true, + "uint32": true, + "uint64": true, + "uint8": true, + "uintptr": true, +} + +func synthesizeMissingRouteContractStates(component *Component, routes []*plan.ComponentRoute) []*plan.State { + if component == nil || len(routes) == 0 { + return nil + } + declared := map[string]bool{} + register := func(items []*plan.State) { + for _, item := range items { + if item == nil || item.In == nil { + continue + } + declared[routeContractStateKey(item.Name, item.In.Kind, item.In.Name)] = true + } + } + register(component.Input) + register(component.Output) + register(component.Meta) + register(component.Async) + register(component.Other) + + var result []*plan.State + for _, route := range routes { + if route == nil { + continue + } + for _, item := range contractStates(route.InputType) { + key := routeContractStateKey(item.Name, item.In.Kind, item.In.Name) + if declared[key] { + continue + } + declared[key] = true + result = append(result, item) + } + for _, item := range contractStates(route.OutputType) { + key := routeContractStateKey(item.Name, item.In.Kind, item.In.Name) + if declared[key] { + continue + } + declared[key] = true + result = append(result, item) + } + } + return result +} + +func contractStates(rType reflect.Type) []*plan.State { + rType = unwrapContractStateType(rType) + if rType == nil || rType.Kind() != reflect.Struct { + return nil + } + var result []*plan.State + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if field.Anonymous { + result = append(result, contractStates(field.Type)...) + } + parsed, err := tags.ParseStateTags(field.Tag, nil) + if err != nil || parsed == nil || parsed.Parameter == nil { + continue + } + param := parsed.Parameter + name := strings.TrimSpace(param.Name) + if name == "" { + name = strings.TrimSpace(field.Name) + } + locationKind := state.Kind(strings.ToLower(strings.TrimSpace(param.Kind))) + locationName := strings.TrimSpace(param.In) + item := &plan.State{ + Parameter: state.Parameter{ + Name: name, + In: &state.Location{Kind: locationKind, Name: locationName}, + When: param.When, + Scope: param.Scope, + Required: param.Required, + Async: param.Async, + Cacheable: param.Cacheable, + With: param.With, + URI: param.URI, + ErrorStatusCode: param.ErrorCode, + ErrorMessage: param.ErrorMessage, + Tag: string(field.Tag), + Schema: state.NewSchema(field.Type), + }, + } + state.BuildCodec(parsed, &item.Parameter) + state.BuildHandler(parsed, &item.Parameter) + if dataType := strings.TrimSpace(param.DataType); dataType != "" && item.Schema != nil { + item.Schema.DataType = dataType + } + result = append(result, item) + } + return result +} + +func unwrapContractStateType(rType reflect.Type) reflect.Type { + for rType != nil && (rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array) { + rType = rType.Elem() + } + return rType +} + +func routeContractStateKey(name string, kind state.Kind, in string) string { + return strings.ToLower(strings.TrimSpace(name)) + "|" + strings.ToLower(strings.TrimSpace(string(kind))) + "|" + strings.ToLower(strings.TrimSpace(in)) +} + +// synthesizePredicateStates creates query parameters for view-level predicates whose +// source parameter is not already present in the input state list. +func synthesizePredicateStates(input []*plan.State, predicates map[string][]*plan.ViewPredicate) []*plan.State { + if len(predicates) == 0 { + return nil + } + declared := make(map[string]bool, len(input)) + for _, s := range input { + if s != nil { + declared[strings.ToLower(strings.TrimPrefix(strings.TrimSpace(s.Name), "$"))] = true + } + } + var result []*plan.State + for _, viewPredicates := range predicates { + for _, vp := range viewPredicates { + if vp == nil { + continue + } + src := strings.TrimPrefix(strings.TrimSpace(vp.Source), "$") + if src == "" || declared[strings.ToLower(src)] { + continue + } + result = append(result, &plan.State{ + Parameter: state.Parameter{ + Name: src, + In: state.NewQueryLocation(src), + Schema: &state.Schema{DataType: "string"}, + Predicates: []*extension.PredicateConfig{ + { + Name: vp.Name, + Ensure: vp.Ensure, + Args: append([]string{}, vp.Arguments...), + }, + }, + }, + }) + declared[strings.ToLower(src)] = true + } + } + return result +} + +func synthesizeConstStates(constants map[string]string) []*plan.State { + if len(constants) == 0 { + return nil + } + keys := make([]string, 0, len(constants)) + for key := range constants { + key = strings.TrimSpace(key) + if key != "" { + keys = append(keys, key) + } + } + sort.Strings(keys) + result := make([]*plan.State, 0, len(keys)) + for _, key := range keys { + result = append(result, &plan.State{ + Parameter: state.Parameter{ + Name: key, + In: state.NewConstLocation(key), + Value: constants[key], + Tag: `internal:"true"`, + Schema: &state.Schema{ + Name: "string", + DataType: "string", + Cardinality: state.One, + }, + }, + }) + } + return result +} + +func cloneTypeContext(input *typectx.Context) *typectx.Context { + if input == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + PackageDir: strings.TrimSpace(input.PackageDir), + PackageName: strings.TrimSpace(input.PackageName), + PackagePath: strings.TrimSpace(input.PackagePath), + } + for _, item := range input.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + ret.Imports = append(ret.Imports, typectx.Import{ + Alias: strings.TrimSpace(item.Alias), + Package: pkg, + }) + } + if ret.DefaultPackage == "" && + len(ret.Imports) == 0 && + ret.PackageDir == "" && + ret.PackageName == "" && + ret.PackagePath == "" { + return nil + } + return ret +} + +func cloneDirectives(input *dqlshape.Directives) *dqlshape.Directives { + if input == nil { + return nil + } + ret := &dqlshape.Directives{ + Meta: strings.TrimSpace(input.Meta), + DefaultConnector: strings.TrimSpace(input.DefaultConnector), + TemplateType: strings.TrimSpace(input.TemplateType), + Dest: strings.TrimSpace(input.Dest), + InputDest: strings.TrimSpace(input.InputDest), + OutputDest: strings.TrimSpace(input.OutputDest), + RouterDest: strings.TrimSpace(input.RouterDest), + InputType: strings.TrimSpace(input.InputType), + OutputType: strings.TrimSpace(input.OutputType), + } + if input.Cache != nil { + ret.Cache = &dqlshape.CacheDirective{ + Enabled: input.Cache.Enabled, + TTL: strings.TrimSpace(input.Cache.TTL), + Name: strings.TrimSpace(input.Cache.Name), + Provider: strings.TrimSpace(input.Cache.Provider), + Location: strings.TrimSpace(input.Cache.Location), + TimeToLiveMs: input.Cache.TimeToLiveMs, + } + } + if input.MCP != nil { + ret.MCP = &dqlshape.MCPDirective{ + Name: strings.TrimSpace(input.MCP.Name), + Description: strings.TrimSpace(input.MCP.Description), + DescriptionPath: strings.TrimSpace(input.MCP.DescriptionPath), + } + } + if input.Const != nil { + ret.Const = make(map[string]string, len(input.Const)) + for k, v := range input.Const { + ret.Const[k] = v + } + } + if input.Route != nil { + ret.Route = &dqlshape.RouteDirective{ + URI: strings.TrimSpace(input.Route.URI), + } + for _, m := range input.Route.Methods { + if m = strings.TrimSpace(m); m != "" { + ret.Route.Methods = append(ret.Route.Methods, m) + } + } + } + if input.Report != nil { + ret.Report = &dqlshape.ReportDirective{ + Enabled: input.Report.Enabled, + Input: strings.TrimSpace(input.Report.Input), + Dimensions: strings.TrimSpace(input.Report.Dimensions), + Measures: strings.TrimSpace(input.Report.Measures), + Filters: strings.TrimSpace(input.Report.Filters), + OrderBy: strings.TrimSpace(input.Report.OrderBy), + Limit: strings.TrimSpace(input.Report.Limit), + Offset: strings.TrimSpace(input.Report.Offset), + } + } + if ret.Meta == "" && ret.DefaultConnector == "" && ret.TemplateType == "" && + ret.Dest == "" && ret.InputDest == "" && ret.OutputDest == "" && ret.RouterDest == "" && + ret.InputType == "" && ret.OutputType == "" && + ret.Cache == nil && ret.MCP == nil && ret.Route == nil && ret.Report == nil && len(ret.Const) == 0 { + return nil + } + return ret +} + +func pickRootView(views []*plan.View) *plan.View { + var selected *plan.View + minDepth := -1 + for _, candidate := range views { + if candidate == nil || candidate.Path == "" { + continue + } + depth := strings.Count(candidate.Path, ".") + if minDepth == -1 || depth < minDepth { + minDepth = depth + selected = candidate + } + } + if selected != nil { + return selected + } + for _, candidate := range views { + if candidate != nil { + return candidate + } + } + return nil +} + +func materializeView(item *plan.View) (*view.View, error) { + if item == nil { + return nil, fmt.Errorf("shape load: nil view plan item") + } + + schemaType := bestSchemaType(item) + mode := view.ModeQuery + switch strings.TrimSpace(item.Mode) { + case string(view.ModeExec): + mode = view.ModeExec + case string(view.ModeHandler): + mode = view.ModeHandler + case string(view.ModeQuery): + mode = view.ModeQuery + } + if shouldDeferQuerySchemaType(schemaType, mode) { + schemaType = nil + } + if schemaType == nil && !allowsDeferredSchema(item, mode) { + return nil, fmt.Errorf("shape load: missing schema type for view %q", item.Name) + } + + schema := newSchema(schemaType, item.Cardinality) + opts := []view.Option{view.WithSchema(schema), view.WithMode(mode)} + if item.Groupable != nil { + opts = append(opts, view.WithGroupable(*item.Groupable)) + } + + if item.Connector != "" { + opts = append(opts, view.WithConnectorRef(item.Connector)) + } + if item.SQL != "" || item.SQLURI != "" { + tmpl := view.NewTemplate(item.SQL) + tmpl.SourceURL = item.SQLURI + opts = append(opts, view.WithTemplate(tmpl)) + } + if strings.TrimSpace(item.Summary) != "" || strings.TrimSpace(item.SummaryURL) != "" { + name := strings.TrimSpace(item.SummaryName) + if name == "" { + name = "Summary" + } + opts = append(opts, view.WithSummary(&view.TemplateSummary{ + Name: name, + Source: item.Summary, + SourceURL: item.SummaryURL, + Kind: view.MetaKindRecord, + })) + } + if item.CacheRef != "" { + opts = append(opts, view.WithCache(&view.Cache{Reference: shared.Reference{Ref: item.CacheRef}})) + } + if item.Partitioner != "" { + opts = append(opts, view.WithPartitioned(&view.Partitioned{ + DataType: item.Partitioner, + Concurrency: item.PartitionedConcurrency, + })) + } + + aView, err := view.New(item.Name, item.Table, opts...) + if err != nil { + return nil, err + } + aView.Ref = item.Ref + aView.Module = item.Module + aView.AllowNulls = item.AllowNulls + // Gap 6: forward view-level tag from declaration. + if item.Declaration != nil && strings.TrimSpace(item.Declaration.Tag) != "" { + aView.Tag = strings.TrimSpace(item.Declaration.Tag) + } + if item.Declaration != nil && len(item.Declaration.ColumnsConfig) > 0 { + if aView.ColumnsConfig == nil { + aView.ColumnsConfig = map[string]*view.ColumnConfig{} + } + for name, cfg := range item.Declaration.ColumnsConfig { + name = strings.TrimSpace(name) + if name == "" || cfg == nil { + continue + } + columnCfg := aView.ColumnsConfig[name] + if columnCfg == nil { + columnCfg = &view.ColumnConfig{Name: name} + aView.ColumnsConfig[name] = columnCfg + } + if dataType := strings.TrimSpace(cfg.DataType); dataType != "" { + columnCfg.DataType = stringPtr(dataType) + } + if tag := strings.TrimSpace(cfg.Tag); tag != "" { + columnCfg.Tag = stringPtr(tag) + } + if cfg.Groupable != nil { + columnCfg.Groupable = boolPtr(*cfg.Groupable) + } + } + } + if strings.TrimSpace(item.SelectorNamespace) != "" || item.SelectorNoLimit != nil || item.SelectorLimit != nil || + item.SelectorCriteria != nil || item.SelectorProjection != nil || item.SelectorOrderBy != nil || + item.SelectorOffset != nil || item.SelectorPage != nil || len(item.SelectorFilterable) > 0 || + len(item.SelectorOrderByColumns) > 0 { + if aView.Selector == nil { + aView.Selector = &view.Config{} + } + if aView.Selector.Constraints == nil { + aView.Selector.Constraints = &view.Constraints{} + } + if strings.TrimSpace(item.SelectorNamespace) != "" { + aView.Selector.Namespace = strings.TrimSpace(item.SelectorNamespace) + } + if item.SelectorNoLimit != nil { + aView.Selector.NoLimit = *item.SelectorNoLimit + aView.Selector.Constraints.Limit = true + } + if item.SelectorLimit != nil { + aView.Selector.Limit = *item.SelectorLimit + aView.Selector.Constraints.Limit = true + } + if item.SelectorCriteria != nil || item.SelectorProjection != nil || item.SelectorOrderBy != nil || + item.SelectorOffset != nil || item.SelectorPage != nil || len(item.SelectorFilterable) > 0 || + len(item.SelectorOrderByColumns) > 0 { + if item.SelectorCriteria != nil { + aView.Selector.Constraints.Criteria = *item.SelectorCriteria + } + if item.SelectorProjection != nil { + aView.Selector.Constraints.Projection = *item.SelectorProjection + } + if item.SelectorOrderBy != nil { + aView.Selector.Constraints.OrderBy = *item.SelectorOrderBy + } + if item.SelectorOffset != nil { + aView.Selector.Constraints.Offset = *item.SelectorOffset + } + if item.SelectorPage != nil { + value := *item.SelectorPage + aView.Selector.Constraints.Page = &value + } + if len(item.SelectorFilterable) > 0 { + aView.Selector.Constraints.Filterable = append([]string(nil), item.SelectorFilterable...) + } + if len(item.SelectorOrderByColumns) > 0 { + aView.Selector.Constraints.OrderByColumn = map[string]string{} + for key, value := range item.SelectorOrderByColumns { + aView.Selector.Constraints.OrderByColumn[key] = value + } + } + } + } + if item.Self != nil { + aView.SelfReference = &view.SelfReference{ + Holder: item.Self.Holder, + Child: item.Self.Child, + Parent: item.Self.Parent, + } + } + if aView.Schema != nil && strings.TrimSpace(item.SchemaType) != "" { + if aView.Schema.DataType == "" { + aView.Schema.DataType = strings.TrimSpace(item.SchemaType) + } + if aView.Schema.Name == "" { + aView.Schema.Name = strings.Trim(strings.TrimSpace(item.SchemaType), "*") + } + } + // Populate columns from statically-inferred struct type so that xgen can + // generate accurate Go struct definitions during bootstrap. Only applied when + // the view has no columns yet (avoids overwriting explicit column config). + if len(aView.Columns) == 0 { + if cols := inferColumnsFromType(bestSchemaType(item)); len(cols) > 0 && !inferredColumnsArePlaceholders(cols) { + aView.Columns = cols + } + } + if aView.Schema != nil && aView.Schema.Type() == nil { + if rowType := synthesizeViewSchemaType(aView); rowType != nil { + aView.Schema.SetType(rowType) + aView.Schema.EnsurePointer() + } + } + return aView, nil +} + +func assignViewSummarySchemas(resource *view.Resource, pResult *plan.Result, source *shape.Source) { + if resource == nil || pResult == nil { + return + } + index := resource.Views.Index() + for _, item := range pResult.Views { + if item == nil || strings.TrimSpace(item.Summary) == "" { + continue + } + aView, err := index.Lookup(item.Name) + if err != nil || aView == nil || aView.Template == nil || aView.Template.Summary == nil { + continue + } + if schema := aView.Template.Summary.Schema; schema != nil && (schema.Type() != nil || (strings.TrimSpace(schema.DataType) != "" && strings.TrimSpace(schema.DataType) != "?")) { + continue + } + summaryType := resolveSummarySchemaType(source, pResult.TypeContext, item.SummaryName) + if summaryType == nil { + summaryType = inferSummarySchemaType(item) + } + if summaryType == nil { + continue + } + aView.Template.Summary.Schema = materializedSummarySchema(summaryType, item.SummaryName, pResult.TypeContext) + } +} + +func inferSummarySchemaType(item *plan.View) reflect.Type { + if item == nil { + return nil + } + summarySQL := strings.TrimSpace(item.Summary) + if summarySQL == "" { + return nil + } + queryNode, _, err := pipeline.ParseSelectWithDiagnostic(pipeline.NormalizeParserSQL(summarySQL)) + if err != nil || queryNode == nil { + return nil + } + _, elementType, _ := pipeline.InferProjectionType(queryNode) + return unwrapSummarySchemaType(elementType) +} + +func unwrapSummarySchemaType(rType reflect.Type) reflect.Type { + for rType != nil { + switch rType.Kind() { + case reflect.Slice, reflect.Array, reflect.Ptr: + rType = rType.Elem() + default: + return rType + } + } + return nil +} + +func resolveSummarySchemaType(source *shape.Source, ctx *typectx.Context, summaryName string) reflect.Type { + summaryName = strings.TrimSpace(summaryName) + if summaryName == "" || source == nil { + return nil + } + registry := source.EnsureTypeRegistry() + if registry == nil { + return nil + } + candidates := []string{summaryName} + if !strings.HasSuffix(summaryName, "View") { + candidates = append([]string{summaryName + "View"}, candidates...) + } + resolver := typectx.NewResolver(registry, ctx) + for _, candidate := range candidates { + if lookup := registry.Lookup(candidate); lookup != nil && lookup.Type != nil { + return lookup.Type + } + if resolved, err := resolver.Resolve(candidate); err == nil && resolved != "" { + if lookup := registry.Lookup(resolved); lookup != nil && lookup.Type != nil { + return lookup.Type + } + } + } + return nil +} + +func resolveViewSchemaType(source *shape.Source, ctx *typectx.Context, aView *view.View, typeName string) reflect.Type { + candidates := []string{strings.TrimSpace(typeName)} + if aView != nil && aView.Schema != nil { + if name := strings.TrimSpace(aView.Schema.Name); name != "" { + candidates = append([]string{name}, candidates...) + } + } + seen := map[string]bool{} + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" || seen[candidate] { + continue + } + seen[candidate] = true + if astType := resolveViewSchemaASTType(source, ctx, aView, candidate); astType != nil { + return astType + } + if source != nil { + registry := source.EnsureTypeRegistry() + if registry != nil { + resolver := typectx.NewResolver(registry, ctx) + if lookup := registry.Lookup(candidate); lookup != nil && lookup.Type != nil { + return lookup.Type + } + if resolved, err := resolver.Resolve(candidate); err == nil && resolved != "" { + if lookup := registry.Lookup(resolved); lookup != nil && lookup.Type != nil { + return lookup.Type + } + } + } + } + } + return nil +} + +func materializeConcreteViewSchemas(resource *view.Resource, source *shape.Source, ctx *typectx.Context) { + if resource == nil { + return + } + visited := map[*view.View]bool{} + for _, aView := range resource.Views { + applyConcreteViewSchemaType(aView, source, ctx, visited) + } +} + +func enrichConcreteViewColumns(resource *view.Resource) { + if resource == nil { + return + } + visited := map[*view.View]bool{} + for _, aView := range resource.Views { + enrichViewColumnsFromSchema(aView, visited) + } +} + +func enrichViewColumnsFromSchema(aView *view.View, visited map[*view.View]bool) { + if aView == nil || visited[aView] { + return + } + visited[aView] = true + appendMissingColumnsFromSchema(aView) + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + enrichViewColumnsFromSchema(&rel.Of.View, visited) + } +} + +func refineViewColumnConfigTypes(resource *view.Resource, source *shape.Source, ctx *typectx.Context) { + if resource == nil { + return + } + visited := map[*view.View]bool{} + for _, aView := range resource.Views { + refineViewColumnConfigType(aView, source, ctx, visited) + } +} + +func refineViewColumnConfigType(aView *view.View, source *shape.Source, ctx *typectx.Context, visited map[*view.View]bool) { + if aView == nil || visited[aView] { + return + } + visited[aView] = true + applyConfiguredColumnTypes(aView, source, ctx) + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + refineViewColumnConfigType(&rel.Of.View, source, ctx, visited) + } +} + +func applyConfiguredColumnTypes(aView *view.View, source *shape.Source, ctx *typectx.Context) { + if aView == nil || len(aView.ColumnsConfig) == 0 { + return + } + if aView.Schema != nil && aView.Schema.Type() != nil { + if refined := refineSchemaTypeByColumnConfig(aView.Schema.Type(), aView.ColumnsConfig, source, ctx); refined != nil && refined != aView.Schema.Type() { + aView.Schema.SetType(refined) + aView.Schema.EnsurePointer() + } + } + for _, column := range aView.Columns { + if column == nil { + continue + } + cfg := lookupColumnConfig(aView.ColumnsConfig, column.Name, column.DatabaseColumn, column.FieldName()) + if cfg == nil || strings.TrimSpace(valueOrEmpty(cfg.DataType)) == "" { + continue + } + if resolved := resolveColumnConfigType(strings.TrimSpace(*cfg.DataType), source, ctx); resolved != nil { + column.DataType = strings.TrimSpace(*cfg.DataType) + column.SetColumnType(resolved) + } + } +} + +func refineSchemaTypeByColumnConfig(rType reflect.Type, configs map[string]*view.ColumnConfig, source *shape.Source, ctx *typectx.Context) reflect.Type { + if rType == nil { + return nil + } + switch rType.Kind() { + case reflect.Ptr: + if refined := refineSchemaTypeByColumnConfig(rType.Elem(), configs, source, ctx); refined != nil && refined != rType.Elem() { + return reflect.PtrTo(refined) + } + return rType + case reflect.Slice: + if refined := refineSchemaTypeByColumnConfig(rType.Elem(), configs, source, ctx); refined != nil && refined != rType.Elem() { + return reflect.SliceOf(refined) + } + return rType + case reflect.Array: + if refined := refineSchemaTypeByColumnConfig(rType.Elem(), configs, source, ctx); refined != nil && refined != rType.Elem() { + return reflect.ArrayOf(rType.Len(), refined) + } + return rType + case reflect.Struct: + fields := make([]reflect.StructField, 0, rType.NumField()) + changed := false + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + cfg := lookupColumnConfig(configs, field.Name, summaryLookupName(field)) + if cfg != nil && strings.TrimSpace(valueOrEmpty(cfg.DataType)) != "" { + if resolved := resolveColumnConfigType(strings.TrimSpace(*cfg.DataType), source, ctx); resolved != nil && resolved != field.Type { + field.Type = resolved + changed = true + } + } + fields = append(fields, field) + } + if changed { + return reflect.StructOf(fields) + } + } + return rType +} + +func lookupColumnConfig(configs map[string]*view.ColumnConfig, names ...string) *view.ColumnConfig { + if len(configs) == 0 { + return nil + } + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" { + continue + } + if cfg := configs[name]; cfg != nil { + return cfg + } + for key, cfg := range configs { + if strings.EqualFold(strings.TrimSpace(key), name) { + return cfg + } + } + } + return nil +} + +func valueOrEmpty(value *string) string { + if value == nil { + return "" + } + return *value +} + +func resolveColumnConfigType(dataType string, source *shape.Source, ctx *typectx.Context) reflect.Type { + dataType = strings.TrimSpace(dataType) + if dataType == "" { + return nil + } + if resolved, err := types.LookupType(extension.Config.Types.Lookup, dataType); err == nil && resolved != nil { + return resolved + } + if source == nil { + return nil + } + registry := source.EnsureTypeRegistry() + if registry == nil { + return nil + } + resolver := typectx.NewResolver(registry, ctx) + if lookup := registry.Lookup(dataType); lookup != nil && lookup.Type != nil { + return lookup.Type + } + if resolved, err := resolver.Resolve(dataType); err == nil && resolved != "" { + if lookup := registry.Lookup(resolved); lookup != nil && lookup.Type != nil { + return lookup.Type + } + } + return nil +} + +func appendMissingColumnsFromSchema(aView *view.View) { + if aView == nil || aView.Schema == nil || aView.Schema.Type() == nil { + return + } + structType := types.EnsureStruct(aView.Schema.Type()) + if structType == nil || structType.Kind() != reflect.Struct { + return + } + ioColumns, err := sqlxio.StructColumns(structType, "sqlx") + if err != nil || len(ioColumns) == 0 { + return + } + type columnMeta struct { + dataType string + nullable bool + } + metadata := map[string]columnMeta{} + for _, ioColumn := range ioColumns { + if ioColumn == nil { + continue + } + meta := columnMeta{dataType: columnDataTypeFromScanType(ioColumn.ScanType())} + meta.nullable, _ = ioColumn.Nullable() + tagName := "" + if tag := ioColumn.Tag(); tag != nil { + tagName = strings.TrimSpace(tag.Name()) + } + for _, key := range []string{ + strings.ToUpper(strings.TrimSpace(ioColumn.Name())), + strings.ToUpper(tagName), + } { + if key != "" { + metadata[key] = meta + } + } + } + for _, column := range aView.Columns { + if column == nil || strings.TrimSpace(column.DataType) != "" { + continue + } + for _, key := range []string{ + strings.ToUpper(strings.TrimSpace(column.Name)), + strings.ToUpper(strings.TrimSpace(column.DatabaseColumn)), + strings.ToUpper(strings.TrimSpace(column.FieldName())), + } { + if meta, ok := metadata[key]; ok { + if meta.dataType != "" { + column.DataType = meta.dataType + } + column.Nullable = meta.nullable + break + } + } + } + existing := map[string]bool{} + for _, column := range aView.Columns { + if column == nil { + continue + } + for _, key := range []string{ + strings.ToUpper(strings.TrimSpace(column.Name)), + strings.ToUpper(strings.TrimSpace(column.DatabaseColumn)), + strings.ToUpper(strings.TrimSpace(column.FieldName())), + } { + if key != "" { + existing[key] = true + } + } + } + for _, ioColumn := range ioColumns { + name := strings.TrimSpace(ioColumn.Name()) + if name == "" || existing[strings.ToUpper(name)] { + continue + } + tagValue := "" + if tag := ioColumn.Tag(); tag != nil { + tagValue = tag.Raw + if tag.Ns != "" { + if strings.HasSuffix(tagValue, `"`) { + tagValue = strings.TrimRight(tagValue, `"`) + ",ns=" + tag.Ns + `"` + } else { + tagValue += `",ns=` + tag.Ns + `"` + } + } + } + nullable, _ := ioColumn.Nullable() + column := view.NewColumn(name, ioColumn.DatabaseTypeName(), ioColumn.ScanType(), nullable, view.WithColumnTag(tagValue)) + if stateTag, _ := tags.ParseStateTags(reflect.StructTag(column.Tag), nil); stateTag != nil { + if stateTag.Format != nil { + column.FormatTag = stateTag.Format + } + if codec := stateTag.Codec; codec != nil { + column.Codec = &state.Codec{Name: codec.Name, Args: codec.Arguments} + } + } + aView.Columns = append(aView.Columns, column) + existing[strings.ToUpper(name)] = true + if dbName := strings.ToUpper(strings.TrimSpace(column.DatabaseColumn)); dbName != "" { + existing[dbName] = true + } + } +} + +func columnDataTypeFromScanType(scanType reflect.Type) string { + if scanType == nil { + return "" + } + if schema := schemaFromReflectType(scanType); schema != nil { + return strings.TrimSpace(schema.DataType) + } + return strings.TrimSpace(scanType.String()) +} + +func applyConcreteViewSchemaType(aView *view.View, source *shape.Source, ctx *typectx.Context, visited map[*view.View]bool) { + if aView == nil || visited[aView] { + return + } + visited[aView] = true + if aView.Schema != nil { + if resolved := resolveViewSchemaType(source, ctx, aView, relationTypeName(aView)); resolved != nil { + if resolved.Kind() != reflect.Ptr { + resolved = reflect.PtrTo(resolved) + } + aView.Schema.SetType(resolved) + aView.Schema.EnsurePointer() + } + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + applyConcreteViewSchemaType(&rel.Of.View, source, ctx, visited) + } +} + +func resolveViewSchemaASTType(source *shape.Source, ctx *typectx.Context, aView *view.View, typeName string) reflect.Type { + pkgDir := resolveViewSchemaPackageDir(source, ctx, aView) + if pkgDir == "" { + return nil + } + return parseNamedStructType(pkgDir, typeName) +} + +func resolveViewSchemaPackageDir(source *shape.Source, ctx *typectx.Context, aView *view.View) string { + if aView != nil && aView.Schema != nil { + if pkgPath := strings.TrimSpace(firstNonEmpty(aView.Schema.ModulePath, aView.Schema.PackagePath)); pkgPath != "" { + if dir := resolveTypePackageDirFromSource(pkgPath, ctx, source); dir != "" { + return dir + } + } + } + if ctx == nil { + return "" + } + if dir := strings.TrimSpace(ctx.PackageDir); dir != "" { + resolvedDir := dir + if filepath.IsAbs(dir) { + if isUsablePackageDir(dir) { + return dir + } + resolvedDir = dir + } else if moduleRoot := nearestModuleRoot(source); moduleRoot != "" { + resolvedDir = filepath.Join(moduleRoot, filepath.FromSlash(dir)) + if isUsablePackageDir(resolvedDir) { + return resolvedDir + } + } + } + if pkgPath := strings.TrimSpace(firstNonEmpty(ctx.PackagePath, ctx.DefaultPackage)); pkgPath != "" { + return resolveTypePackageDirFromSource(pkgPath, ctx, source) + } + return "" +} + +func isUsablePackageDir(dir string) bool { + if strings.TrimSpace(dir) == "" { + return false + } + info, err := os.Stat(dir) + return err == nil && info.IsDir() +} + +func resolveTypePackageDirFromSource(pkgPath string, ctx *typectx.Context, source *shape.Source) string { + if pkgPath == "" { + return "" + } + moduleRoot := nearestModuleRoot(source) + if moduleRoot == "" { + if ctx != nil && strings.TrimSpace(ctx.PackagePath) == strings.TrimSpace(pkgPath) { + if dir := strings.TrimSpace(ctx.PackageDir); dir != "" { + if filepath.IsAbs(dir) { + return dir + } + } + } + return "" + } + modulePath := detectModulePath(moduleRoot) + if modulePath != "" { + if rel, ok := packagePathRelative(modulePath, pkgPath); ok { + if rel == "" { + return moduleRoot + } + return filepath.Join(moduleRoot, filepath.FromSlash(rel)) + } + } + if ctx != nil && strings.TrimSpace(ctx.PackagePath) == strings.TrimSpace(pkgPath) { + if dir := strings.TrimSpace(ctx.PackageDir); dir != "" { + if filepath.IsAbs(dir) { + return dir + } + return filepath.Join(moduleRoot, filepath.FromSlash(dir)) + } + } + return "" +} + +func detectModulePath(moduleRoot string) string { + if moduleRoot == "" { + return "" + } + data, err := os.ReadFile(filepath.Join(moduleRoot, "go.mod")) + if err != nil { + return "" + } + lines := strings.Split(string(data), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "module ") { + continue + } + return strings.TrimSpace(strings.TrimPrefix(line, "module ")) + } + return "" +} + +func packagePathRelative(modulePath, pkgPath string) (string, bool) { + modulePath = strings.TrimSpace(modulePath) + pkgPath = strings.TrimSpace(pkgPath) + if modulePath == "" || pkgPath == "" { + return "", false + } + if pkgPath == modulePath { + return "", true + } + if !strings.HasPrefix(pkgPath, modulePath+"/") { + return "", false + } + return strings.TrimPrefix(pkgPath, modulePath+"/"), true +} + +func nearestModuleRoot(source *shape.Source) string { + if source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + current := filepath.Dir(strings.TrimSpace(source.Path)) + for current != "" && current != string(filepath.Separator) && current != "." { + if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { + return current + } + parent := filepath.Dir(current) + if parent == current { + break + } + current = parent + } + return "" +} + +func parseNamedStructType(pkgDir, typeName string) reflect.Type { + fset := token.NewFileSet() + pkgs, err := parser.ParseDir(fset, pkgDir, nil, parser.ParseComments) + if err != nil || len(pkgs) == 0 { + return nil + } + specs := map[string]*ast.TypeSpec{} + for _, pkg := range pkgs { + for _, file := range pkg.Files { + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + for _, spec := range gen.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok || typeSpec.Name == nil { + continue + } + specs[typeSpec.Name.Name] = typeSpec + } + } + } + } + cache := map[string]reflect.Type{} + inProgress := map[string]bool{} + var buildNamed func(name string) reflect.Type + var buildExpr func(expr ast.Expr) reflect.Type + + buildNamed = func(name string) reflect.Type { + if cached, ok := cache[name]; ok { + return cached + } + if inProgress[name] { + return reflect.TypeOf(new(interface{})).Elem() + } + spec := specs[name] + if spec == nil { + return nil + } + inProgress[name] = true + rType := buildExpr(spec.Type) + delete(inProgress, name) + if rType != nil { + cache[name] = rType + } + return rType + } + + buildExpr = func(expr ast.Expr) reflect.Type { + switch actual := expr.(type) { + case *ast.Ident: + switch actual.Name { + case "string": + return reflect.TypeOf("") + case "bool": + return reflect.TypeOf(true) + case "int": + return reflect.TypeOf(int(0)) + case "int8": + return reflect.TypeOf(int8(0)) + case "int16": + return reflect.TypeOf(int16(0)) + case "int32": + return reflect.TypeOf(int32(0)) + case "int64": + return reflect.TypeOf(int64(0)) + case "uint": + return reflect.TypeOf(uint(0)) + case "uint8": + return reflect.TypeOf(uint8(0)) + case "uint16": + return reflect.TypeOf(uint16(0)) + case "uint32": + return reflect.TypeOf(uint32(0)) + case "uint64": + return reflect.TypeOf(uint64(0)) + case "float32": + return reflect.TypeOf(float32(0)) + case "float64": + return reflect.TypeOf(float64(0)) + case "interface{}", "any": + return reflect.TypeOf(new(interface{})).Elem() + default: + return buildNamed(actual.Name) + } + case *ast.StarExpr: + if inner := buildExpr(actual.X); inner != nil { + return reflect.PtrTo(inner) + } + case *ast.ArrayType: + if actual.Len == nil { + if inner := buildExpr(actual.Elt); inner != nil { + return reflect.SliceOf(inner) + } + } + case *ast.MapType: + key := buildExpr(actual.Key) + value := buildExpr(actual.Value) + if key != nil && value != nil { + return reflect.MapOf(key, value) + } + case *ast.InterfaceType: + return reflect.TypeOf(new(interface{})).Elem() + case *ast.SelectorExpr: + if ident, ok := actual.X.(*ast.Ident); ok && actual.Sel != nil { + if ident.Name == "time" && actual.Sel.Name == "Time" { + return reflect.TypeOf(time.Time{}) + } + if resolved, err := types.LookupType(extension.Config.Types.Lookup, ident.Name+"."+actual.Sel.Name); err == nil && resolved != nil { + return resolved + } + } + case *ast.StructType: + fields := make([]reflect.StructField, 0, len(actual.Fields.List)) + seen := map[string]bool{} + for _, field := range actual.Fields.List { + if field == nil { + continue + } + fieldType := buildExpr(field.Type) + if fieldType == nil { + continue + } + tag := reflect.StructTag("") + if field.Tag != nil { + tag = reflect.StructTag(strings.Trim(field.Tag.Value, "`")) + } + if len(field.Names) == 0 { + if name := exportedEmbeddedFieldName(field.Type); name != "" { + if seen[name] { + continue + } + seen[name] = true + fields = append(fields, reflect.StructField{Name: name, Type: fieldType, Tag: tag, Anonymous: true}) + } + continue + } + for _, name := range field.Names { + if name == nil || !name.IsExported() { + continue + } + if seen[name.Name] { + continue + } + seen[name.Name] = true + fields = append(fields, reflect.StructField{Name: name.Name, Type: fieldType, Tag: tag}) + } + } + if len(fields) > 0 { + return reflect.StructOf(fields) + } + } + return nil + } + return buildNamed(typeName) +} + +func exportedEmbeddedFieldName(expr ast.Expr) string { + switch actual := expr.(type) { + case *ast.Ident: + if actual.IsExported() { + return actual.Name + } + case *ast.SelectorExpr: + if actual.Sel != nil && actual.Sel.IsExported() { + return actual.Sel.Name + } + case *ast.StarExpr: + return exportedEmbeddedFieldName(actual.X) + } + return "" +} + +func materializedSummarySchema(summaryType reflect.Type, summaryName string, ctx *typectx.Context) *state.Schema { + if summaryType == nil { + return nil + } + if summaryType.Kind() != reflect.Ptr { + summaryType = reflect.PtrTo(summaryType) + } + schema := state.NewSchema(summaryType) + typeName := strings.TrimSpace(summarySchemaName(summaryName)) + if typeName != "" { + schema.Name = typeName + if typeExpr, typePkg := summarySchemaTypeRef(typeName, ctx); typeExpr != "" { + schema.DataType = typeExpr + schema.Package = typePkg + if ctx != nil { + schema.PackagePath = strings.TrimSpace(ctx.PackagePath) + } + } + } + schema.EnsurePointer() + return schema +} + +func refineSummarySchemas(resource *view.Resource) { + if resource == nil { + return + } + visited := map[*view.View]bool{} + for _, aView := range resource.Views { + refineViewSummarySchemas(aView, visited) + } +} + +func materializeResourceTypes(resource *view.Resource, planned []*plan.View, source *shape.Source, ctx *typectx.Context) { + if resource == nil { + return + } + seen := map[string]bool{} + plannedByName := map[string]*plan.View{} + for _, item := range planned { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + plannedByName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, item := range resource.Types { + if item == nil { + continue + } + name := strings.ToLower(strings.TrimSpace(item.Name)) + if name == "" { + continue + } + seen[name] = true + } + visited := map[*view.View]bool{} + for _, aView := range resource.Views { + collectViewTypes(aView, resource, seen, visited, plannedByName, source, ctx) + } +} + +func collectViewTypes(aView *view.View, resource *view.Resource, seen map[string]bool, visited map[*view.View]bool, plannedByName map[string]*plan.View, source *shape.Source, ctx *typectx.Context) { + if aView == nil || resource == nil || visited[aView] { + return + } + visited[aView] = true + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + collectViewTypes(&rel.Of.View, resource, seen, visited, plannedByName, source, ctx) + } + if aView.Template != nil && aView.Template.Summary != nil { + addSchemaTypeDefinition(resource, aView.Template.Summary.Schema, seen) + } + addViewTypeDefinition(resource, aView, seen, plannedByName, source, ctx) +} + +func addViewTypeDefinition(resource *view.Resource, aView *view.View, seen map[string]bool, plannedByName map[string]*plan.View, source *shape.Source, ctx *typectx.Context) { + if resource == nil || aView == nil { + return + } + baseName := strings.TrimSpace(aView.Ref) + if baseName == "" { + baseName = strings.TrimSpace(aView.Name) + } + typeName := state.SanitizeTypeName(baseName) + "View" + key := strings.ToLower(typeName) + if seen[key] { + return + } + def := &view.TypeDefinition{ + Name: typeName, + Package: viewSchemaPackage(aView), + ModulePath: viewSchemaModulePath(aView), + Ptr: viewSchemaPtr(aView), + } + fieldNames := map[string]bool{} + typedFields := collectTypedViewDefinitionFields(aView, plannedByName, source, ctx, typeName) + if len(typedFields) > 0 { + for _, field := range typedFields { + addTypeDefinitionField(def, fieldNames, field) + } + } else { + for _, column := range aView.Columns { + if field := typeDefinitionFieldFromColumn(aView, column); field != nil { + addTypeDefinitionField(def, fieldNames, field) + } + } + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + if field := typeDefinitionFieldFromRelation(rel); field != nil { + addTypeDefinitionField(def, fieldNames, field) + } + if field := typeDefinitionFieldFromRelationSummary(rel); field != nil { + addTypeDefinitionField(def, fieldNames, field) + } + } + if len(def.Fields) == 0 { + return + } + resource.Types = append(resource.Types, def) + seen[key] = true +} + +func collectTypedViewDefinitionFields(aView *view.View, plannedByName map[string]*plan.View, source *shape.Source, ctx *typectx.Context, typeName string) []*view.Field { + var result []*view.Field + seen := map[string]bool{} + appendFields := func(fields []*view.Field) { + for _, field := range fields { + if field == nil { + continue + } + name := strings.TrimSpace(field.Name) + if name == "" || seen[name] { + continue + } + seen[name] = true + result = append(result, field) + } + } + appendFields(typeDefinitionFieldsFromSchema(aView.Schema)) + appendFields(typeDefinitionFieldsFromReflectType(resolveViewSchemaType(source, ctx, aView, typeName))) + if len(result) == 0 { + appendFields(typeDefinitionFieldsFromPlannedView(plannedViewFor(aView, plannedByName))) + } + return result +} + +func addTypeDefinitionField(def *view.TypeDefinition, names map[string]bool, field *view.Field) { + if def == nil || field == nil { + return + } + name := strings.TrimSpace(field.Name) + if name == "" || names[name] { + return + } + names[name] = true + def.AddField(field) +} + +func plannedViewFor(aView *view.View, plannedByName map[string]*plan.View) *plan.View { + if aView == nil || len(plannedByName) == 0 { + return nil + } + for _, key := range []string{strings.TrimSpace(aView.Ref), strings.TrimSpace(aView.Name)} { + if key == "" { + continue + } + if item := plannedByName[strings.ToLower(key)]; item != nil { + return item + } + } + return nil +} + +func addSchemaTypeDefinition(resource *view.Resource, schema *state.Schema, seen map[string]bool) { + addSchemaTypeDefinitionWithName(resource, schema, strings.TrimSpace(summarySchemaName(schemaName(schema))), seen) +} + +func addSchemaTypeDefinitionWithName(resource *view.Resource, schema *state.Schema, typeName string, seen map[string]bool) { + if resource == nil || schema == nil { + return + } + name := strings.TrimSpace(typeName) + if name == "" { + return + } + key := strings.ToLower(name) + if seen[key] { + return + } + cloned := schema.Clone() + cloned.Name = name + if cloned.DataType == "" && cloned.Type() != nil { + cloned.DataType = cloned.TypeName() + } + resource.Types = append(resource.Types, &view.TypeDefinition{ + Name: name, + DataType: strings.TrimSpace(cloned.DataType), + Cardinality: cloned.Cardinality, + Package: strings.TrimSpace(cloned.Package), + ModulePath: firstNonEmpty(strings.TrimSpace(cloned.ModulePath), strings.TrimSpace(cloned.PackagePath)), + Schema: cloned, + }) + seen[key] = true +} + +func schemaName(schema *state.Schema) string { + if schema == nil { + return "" + } + return schema.Name +} + +func viewSchemaPackage(aView *view.View) string { + if aView == nil || aView.Schema == nil { + return "" + } + return strings.TrimSpace(aView.Schema.Package) +} + +func viewSchemaModulePath(aView *view.View) string { + if aView == nil || aView.Schema == nil { + return "" + } + return firstNonEmpty(strings.TrimSpace(aView.Schema.ModulePath), strings.TrimSpace(aView.Schema.PackagePath)) +} + +func viewSchemaPtr(aView *view.View) bool { + if aView == nil || aView.Schema == nil { + return false + } + if rType := aView.Schema.Type(); rType != nil { + if rType.Kind() == reflect.Slice { + rType = rType.Elem() + } + return rType.Kind() == reflect.Ptr + } + typeName := strings.TrimSpace(firstNonEmpty(aView.Schema.DataType, aView.Schema.Name)) + return strings.HasPrefix(typeName, "*") +} + +func typeDefinitionFieldFromColumn(aView *view.View, column *view.Column) *view.Field { + if column == nil { + return nil + } + fieldName := strings.TrimSpace(column.FieldName()) + if fieldName == "" && column.Field() != nil { + fieldName = strings.TrimSpace(column.Field().Name) + } + if fieldName == "" { + caseFormat := text.CaseFormatUpperCamel + if aView != nil && aView.CaseFormat != "" { + caseFormat = aView.CaseFormat + } + fieldName = state.StructFieldName(caseFormat, column.Name) + } + if fieldName == "" { + return nil + } + schema := columnFieldSchema(column) + if schema == nil { + return nil + } + return &view.Field{ + Name: fieldName, + Column: strings.TrimSpace(column.DatabaseColumn), + FromName: fieldName, + Schema: schema, + Tag: strings.TrimSpace(column.Tag), + Cardinality: schema.Cardinality, + } +} + +func columnFieldSchema(column *view.Column) *state.Schema { + if column == nil { + return nil + } + if rType := column.ColumnType(); rType != nil { + return schemaFromReflectType(rType) + } + if dataType := strings.TrimSpace(column.DataType); dataType != "" { + return &state.Schema{DataType: dataType, Cardinality: state.One} + } + return nil +} + +func typeDefinitionFieldFromRelation(rel *view.Relation) *view.Field { + if rel == nil || rel.Of == nil { + return nil + } + typeName := relationTypeName(&rel.Of.View) + if typeName == "" || strings.TrimSpace(rel.Holder) == "" { + return nil + } + schema := relationSchema(&rel.Of.View, typeName, rel.Cardinality) + return &view.Field{ + Name: strings.TrimSpace(rel.Holder), + Schema: schema, + Cardinality: rel.Cardinality, + } +} + +func typeDefinitionFieldFromRelationSummary(rel *view.Relation) *view.Field { + if rel == nil || rel.Of == nil || rel.Of.View.Template == nil || rel.Of.View.Template.Summary == nil || rel.Of.View.Template.Summary.Schema == nil { + return nil + } + name := strings.TrimSpace(rel.Of.View.Template.Summary.Name) + if name == "" { + return nil + } + schema := rel.Of.View.Template.Summary.Schema.Clone() + schema.EnsurePointer() + return &view.Field{ + Name: name, + Schema: schema, + Tag: `json:",omitempty" yaml:",omitempty" sqlx:"-"`, + } +} + +func relationTypeName(aView *view.View) string { + if aView == nil { + return "" + } + baseName := strings.TrimSpace(aView.Ref) + if baseName == "" { + baseName = strings.TrimSpace(aView.Name) + } + if baseName == "" { + return "" + } + return state.SanitizeTypeName(baseName) + "View" +} + +func relationSchema(aView *view.View, typeName string, cardinality state.Cardinality) *state.Schema { + schema := &state.Schema{ + Name: typeName, + DataType: "*" + typeName, + Cardinality: cardinality, + } + if aView != nil && aView.Schema != nil { + schema.Package = strings.TrimSpace(aView.Schema.Package) + schema.PackagePath = strings.TrimSpace(aView.Schema.PackagePath) + schema.ModulePath = firstNonEmpty(strings.TrimSpace(aView.Schema.ModulePath), strings.TrimSpace(aView.Schema.PackagePath)) + } + return schema +} + +func typeDefinitionFieldsFromSchema(schema *state.Schema) []*view.Field { + if schema == nil || schema.Type() == nil { + return nil + } + rType := schema.Type() + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + result := make([]*view.Field, 0, rType.NumField()) + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if !field.IsExported() { + continue + } + result = append(result, &view.Field{ + Name: field.Name, + Schema: schemaFromReflectType(field.Type), + Tag: string(field.Tag), + FromName: field.Name, + Cardinality: state.One, + }) + } + return result +} + +func typeDefinitionFieldsFromPlannedView(item *plan.View) []*view.Field { + if item == nil { + return nil + } + return typeDefinitionFieldsFromReflectType(bestSchemaType(item)) +} + +func typeDefinitionFieldsFromReflectType(rType reflect.Type) []*view.Field { + if rType == nil { + return nil + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + result := make([]*view.Field, 0, rType.NumField()) + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if !field.IsExported() { + continue + } + result = append(result, &view.Field{ + Name: field.Name, + Schema: schemaFromReflectType(field.Type), + Tag: string(field.Tag), + FromName: field.Name, + Cardinality: state.One, + }) + } + return result +} + +func schemaFromReflectType(rType reflect.Type) *state.Schema { + if rType == nil { + return nil + } + schema := state.NewSchema(rType) + if schema == nil { + return nil + } + if schema.Name == "" && schema.DataType == "" { + schema.DataType = rType.String() + if schema.Cardinality == "" { + schema.Cardinality = state.One + } + } + return schema +} + +func synthesizeViewSchemaType(aView *view.View) reflect.Type { + return synthesizeViewSchemaTypeWithOptions(aView, false) +} + +func synthesizeViewSchemaTypeWithOptions(aView *view.View, includeVelty bool) reflect.Type { + if aView == nil || len(aView.Columns) == 0 { + return nil + } + fields := make([]reflect.StructField, 0, len(aView.Columns)) + seen := map[string]bool{} + for _, column := range aView.Columns { + structField := viewStructFieldFromColumn(aView, column, includeVelty) + if structField == nil { + continue + } + if seen[structField.Name] { + continue + } + seen[structField.Name] = true + fields = append(fields, *structField) + } + if len(fields) == 0 { + return nil + } + return reflect.PtrTo(reflect.StructOf(fields)) +} + +func viewStructFieldFromColumn(aView *view.View, column *view.Column, includeVelty bool) *reflect.StructField { + if column == nil { + return nil + } + schema := columnFieldSchema(column) + if schema == nil || schema.Type() == nil { + return nil + } + fieldName := strings.TrimSpace(column.FieldName()) + if fieldName == "" && column.Field() != nil { + fieldName = strings.TrimSpace(column.Field().Name) + } + if fieldName == "" { + caseFormat := text.CaseFormatUpperCamel + if aView != nil && aView.CaseFormat != "" { + caseFormat = aView.CaseFormat + } + fieldName = state.StructFieldName(caseFormat, column.Name) + } + fieldName = strings.TrimSpace(fieldName) + if fieldName == "" { + return nil + } + tag := strings.TrimSpace(column.Tag) + sqlxTag := strings.TrimSpace(strings.TrimSpace(column.DatabaseColumn)) + if sqlxTag == "" { + sqlxTag = strings.TrimSpace(column.Name) + } + if sqlxTag != "" && !strings.Contains(tag, `sqlx:"`) { + if tag != "" { + tag += " " + } + tag += fmt.Sprintf(`sqlx:"%s"`, sqlxTag) + } + if includeVelty && !strings.Contains(tag, `velty:"`) { + veltyNames := []string{column.Name} + if fieldName != "" && fieldName != column.Name { + veltyNames = append(veltyNames, fieldName) + } + if tag != "" { + tag += " " + } + tag += fmt.Sprintf(`velty:"names=%s"`, strings.Join(veltyNames, "|")) + } + return &reflect.StructField{ + Name: fieldName, + Type: schema.Type(), + Tag: reflect.StructTag(tag), + } +} + +func applyVeltyAliasesToExecInputViews(resource *view.Resource, pResult *plan.Result) { + if resource == nil || pResult == nil || !planUsesVelty(resource, pResult) { + return + } + viewNames := map[string]bool{} + for _, item := range pResult.States { + if item == nil || item.In == nil || item.In.Kind != state.KindView { + continue + } + viewName := strings.TrimSpace(item.Name) + if name := strings.TrimSpace(item.In.Name); name != "" { + viewName = name + } + if viewName == "" { + continue + } + viewNames[strings.ToLower(viewName)] = true + } + if len(viewNames) == 0 { + return + } + for _, aView := range resource.Views { + if aView == nil || aView.Schema == nil { + continue + } + if !viewNames[strings.ToLower(strings.TrimSpace(aView.Name))] && + !viewNames[strings.ToLower(strings.TrimSpace(aView.Reference.Ref))] { + continue + } + applyVeltyAliasesToViewColumns(aView) + if !schemaNeedsVeltyAliases(aView.Schema.Type()) { + continue + } + if rebuilt := synthesizeViewSchemaTypeWithOptions(aView, true); rebuilt != nil { + aView.Schema.SetType(rebuilt) + aView.Schema.EnsurePointer() + continue + } + if rebuilt := ensureSchemaTypeVeltyAliases(aView.Schema.Type()); rebuilt != nil { + aView.Schema.SetType(rebuilt) + } + } +} + +func planUsesVelty(resource *view.Resource, pResult *plan.Result) bool { + if pResult == nil { + return false + } + for _, route := range pResult.Components { + if route == nil { + continue + } + method := strings.ToUpper(strings.TrimSpace(route.Method)) + if method != "" && method != "GET" && strings.TrimSpace(route.Handler) == "" { + return true + } + } + if resource != nil { + for _, aView := range resource.Views { + if aView != nil && aView.Mode == view.ModeExec { + return true + } + } + } + return false +} + +func schemaNeedsVeltyAliases(rType reflect.Type) bool { + if rType == nil { + return true + } + if rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return false + } + for i := 0; i < rType.NumField(); i++ { + if strings.TrimSpace(rType.Field(i).Tag.Get("velty")) != "" { + return false + } + } + return true +} + +func applyVeltyAliasesToViewColumns(aView *view.View) { + if aView == nil { + return + } + for _, column := range aView.Columns { + if column == nil || strings.Contains(column.Tag, `velty:"`) { + continue + } + fieldName := strings.TrimSpace(column.FieldName()) + if fieldName == "" && column.Field() != nil { + fieldName = strings.TrimSpace(column.Field().Name) + } + if fieldName == "" { + caseFormat := text.CaseFormatUpperCamel + if aView.CaseFormat != "" { + caseFormat = aView.CaseFormat + } + fieldName = state.StructFieldName(caseFormat, column.Name) + } + veltyNames := []string{column.Name} + if fieldName != "" && fieldName != column.Name { + veltyNames = append(veltyNames, fieldName) + } + tag := strings.TrimSpace(column.Tag) + if tag != "" { + tag += " " + } + tag += fmt.Sprintf(`velty:"names=%s"`, strings.Join(veltyNames, "|")) + column.Tag = strings.TrimSpace(tag) + } +} + +func ensureSchemaTypeVeltyAliases(rType reflect.Type) reflect.Type { + if rType == nil { + return nil + } + original := rType + isSlice := false + if rType.Kind() == reflect.Slice { + isSlice = true + rType = rType.Elem() + } + wasPtr := false + if rType.Kind() == reflect.Ptr { + wasPtr = true + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + fields := make([]reflect.StructField, 0, rType.NumField()) + changed := false + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + tag := string(field.Tag) + if strings.TrimSpace(field.Tag.Get("velty")) == "" { + sqlxName := summaryTagName(field.Tag.Get("sqlx")) + if sqlxName == "" { + sqlxName = field.Name + } + veltyNames := []string{sqlxName} + if field.Name != "" && field.Name != sqlxName { + veltyNames = append(veltyNames, field.Name) + } + if strings.TrimSpace(tag) != "" { + tag += " " + } + tag += fmt.Sprintf(`velty:"names=%s"`, strings.Join(veltyNames, "|")) + changed = true + } + field.Tag = reflect.StructTag(strings.TrimSpace(tag)) + fields = append(fields, field) + } + if !changed { + return original + } + rebuilt := reflect.StructOf(fields) + if wasPtr { + rebuilt = reflect.PtrTo(rebuilt) + } + if isSlice { + rebuilt = reflect.SliceOf(rebuilt) + } + return rebuilt +} + +func refineViewSummarySchemas(aView *view.View, visited map[*view.View]bool) { + if aView == nil || visited[aView] { + return + } + visited[aView] = true + if aView.Template != nil && aView.Template.Summary != nil && aView.Template.Summary.Schema != nil { + if refined := refineSummarySchemaType(aView.Template.Summary.Schema, aView); refined != nil { + aView.Template.Summary.Schema = refined + } + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + refineViewSummarySchemas(&rel.Of.View, visited) + } +} + +func refineSummarySchemaType(summarySchema *state.Schema, ownerView *view.View) *state.Schema { + if summarySchema == nil || ownerView == nil { + return nil + } + summaryType := summarySchema.Type() + if summaryType == nil { + return nil + } + if summaryType.Kind() == reflect.Ptr { + summaryType = summaryType.Elem() + } + if summaryType.Kind() != reflect.Struct { + return nil + } + ownerFields := map[string]reflect.StructField{} + if ownerSchema := ownerView.Schema; ownerSchema != nil && ownerSchema.CompType() != nil { + ownerType := ownerSchema.CompType() + if ownerType.Kind() == reflect.Ptr { + ownerType = ownerType.Elem() + } + if ownerType.Kind() == reflect.Struct { + for i := 0; i < ownerType.NumField(); i++ { + field := ownerType.Field(i) + ownerFields[strings.ToUpper(strings.TrimSpace(field.Name))] = field + if sqlxName := summaryTagName(field.Tag.Get("sqlx")); sqlxName != "" { + ownerFields[strings.ToUpper(sqlxName)] = field + } + } + } + } + for _, column := range ownerView.Columns { + if column == nil { + continue + } + columnType := summaryColumnType(column) + if columnType == nil { + continue + } + fieldName := strings.TrimSpace(column.FieldName()) + if fieldName == "" { + fieldName = strings.TrimSpace(column.Name) + } + field := reflect.StructField{Name: fieldName, Type: columnType, Tag: reflect.StructTag(column.Tag)} + if key := strings.ToUpper(strings.TrimSpace(column.Name)); key != "" { + if _, ok := ownerFields[key]; ok { + continue + } + ownerFields[key] = field + } + if key := strings.ToUpper(strings.TrimSpace(column.DatabaseColumn)); key != "" { + if _, ok := ownerFields[key]; ok { + continue + } + ownerFields[key] = field + } + } + if len(ownerFields) == 0 { + return nil + } + fields := make([]reflect.StructField, 0, summaryType.NumField()) + changed := false + for i := 0; i < summaryType.NumField(); i++ { + field := summaryType.Field(i) + if ownerField, ok := ownerFields[strings.ToUpper(summaryLookupName(field))]; ok && ownerField.Type != nil && ownerField.Type != field.Type { + field.Type = ownerField.Type + changed = true + } + fields = append(fields, field) + } + if !changed { + return nil + } + refinedType := reflect.StructOf(fields) + refined := summarySchema.Clone() + refined.SetType(refinedType) + return refined +} + +func refreshInlineSummarySchemas(ctx context.Context, resource *view.Resource) { + if resource == nil { + return + } + visited := map[*view.View]bool{} + for _, aView := range resource.Views { + refreshViewInlineSummarySchema(ctx, resource, aView, visited) + } +} + +// RefineSummarySchemas reapplies summary schema refinement using current view schema/column metadata. +// This is useful after late column discovery updated view columns post-load. +func RefineSummarySchemas(resource *view.Resource) { + if resource == nil { + return + } + refineSummarySchemas(resource) +} + +func refreshViewInlineSummarySchema(ctx context.Context, resource *view.Resource, aView *view.View, visited map[*view.View]bool) { + if aView == nil || visited[aView] { + return + } + visited[aView] = true + if aView.GetResource() == nil { + aView.SetResource(resource) + } + if shouldRefreshInlineSummarySchema(aView) { + restore := suppressInlineTemplateURLs(aView.Template) + _ = aView.Template.Init(ctx, resource, aView) + restore() + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + if rel.Of.View.GetResource() == nil { + rel.Of.View.SetResource(resource) + } + refreshViewInlineSummarySchema(ctx, resource, &rel.Of.View, visited) + } +} + +func shouldRefreshInlineSummarySchema(aView *view.View) bool { + if aView == nil || aView.Template == nil || aView.Template.Summary == nil { + return false + } + if aView.Connector == nil || strings.TrimSpace(aView.Connector.Ref) == "" { + return false + } + if strings.TrimSpace(aView.Template.Summary.Source) == "" { + return false + } + if strings.TrimSpace(aView.Template.Source) == "" && strings.TrimSpace(aView.Template.SourceURL) == "" { + return false + } + return true +} + +func suppressInlineTemplateURLs(tmpl *view.Template) func() { + if tmpl == nil { + return func() {} + } + sourceURL := tmpl.SourceURL + summarySourceURL := "" + if strings.TrimSpace(tmpl.Source) != "" { + tmpl.SourceURL = "" + } + if tmpl.Summary != nil { + summarySourceURL = tmpl.Summary.SourceURL + if strings.TrimSpace(tmpl.Summary.Source) != "" { + tmpl.Summary.SourceURL = "" + } + } + return func() { + tmpl.SourceURL = sourceURL + if tmpl.Summary != nil { + tmpl.Summary.SourceURL = summarySourceURL + } + } +} + +func summaryColumnType(column *view.Column) reflect.Type { + if column == nil { + return nil + } + if rType := column.ColumnType(); rType != nil { + return rType + } + switch strings.ToLower(strings.TrimSpace(column.DataType)) { + case "int", "integer", "smallint", "signed", "int32": + if column.Nullable { + return reflect.TypeOf((*int)(nil)) + } + return reflect.TypeOf(int(0)) + case "int64", "bigint": + if column.Nullable { + return reflect.TypeOf((*int64)(nil)) + } + return reflect.TypeOf(int64(0)) + case "float", "float32", "real": + if column.Nullable { + return reflect.TypeOf((*float32)(nil)) + } + return reflect.TypeOf(float32(0)) + case "float64", "double", "numeric", "decimal": + if column.Nullable { + return reflect.TypeOf((*float64)(nil)) + } + return reflect.TypeOf(float64(0)) + case "bool", "boolean": + if column.Nullable { + return reflect.TypeOf((*bool)(nil)) + } + return reflect.TypeOf(false) + case "string", "text", "varchar", "char", "uuid", "json", "jsonb", "": + if column.Nullable { + return reflect.TypeOf((*string)(nil)) + } + return reflect.TypeOf("") + default: + return nil + } +} + +func summarySchemaName(summaryName string) string { + summaryName = strings.TrimSpace(summaryName) + if summaryName == "" { + return "" + } + if strings.HasSuffix(summaryName, "View") { + return summaryName + } + return exportedSummaryTypeName(summaryName) + "View" +} + +func summarySchemaTypeRef(typeName string, ctx *typectx.Context) (string, string) { + typeName = strings.TrimSpace(typeName) + if typeName == "" { + return "", "" + } + if ctx != nil { + if pkgAlias := strings.TrimSpace(ctx.PackageName); pkgAlias != "" { + return "*" + pkgAlias + "." + typeName, pkgAlias + } + if pkgPath := strings.TrimSpace(ctx.PackagePath); pkgPath != "" { + alias := summaryPackageAlias(pkgPath, ctx) + return "*" + alias + "." + typeName, alias + } + } + return "*" + typeName, "" +} + +func summaryPackageAlias(pkgPath string, ctx *typectx.Context) string { + pkgPath = strings.TrimSpace(pkgPath) + if pkgPath == "" { + return "" + } + if ctx != nil { + for _, item := range ctx.Imports { + if strings.TrimSpace(item.Package) != pkgPath { + continue + } + if alias := strings.TrimSpace(item.Alias); alias != "" { + return alias + } + } + if strings.TrimSpace(ctx.PackagePath) == pkgPath && strings.TrimSpace(ctx.PackageName) != "" { + return strings.TrimSpace(ctx.PackageName) + } + } + if index := strings.LastIndex(pkgPath, "/"); index != -1 && index+1 < len(pkgPath) { + return pkgPath[index+1:] + } + return pkgPath +} + +func exportedSummaryTypeName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + parts := strings.FieldsFunc(name, func(r rune) bool { + return r == '_' || r == '-' || r == ' ' || r == '.' + }) + if len(parts) == 0 { + return "" + } + var b strings.Builder + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + b.WriteString(strings.ToUpper(part[:1])) + if len(part) > 1 { + b.WriteString(part[1:]) + } + } + return b.String() +} + +func summaryLookupName(field reflect.StructField) string { + if sqlxName := summaryTagName(field.Tag.Get("sqlx")); sqlxName != "" { + return sqlxName + } + return strings.TrimSpace(field.Name) +} + +func summaryTagName(tag string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return "" + } + if strings.HasPrefix(tag, "name=") { + tag = strings.TrimPrefix(tag, "name=") + } + if idx := strings.Index(tag, ","); idx != -1 { + tag = tag[:idx] + } + return strings.TrimSpace(tag) +} + +func allowsDeferredSchema(item *plan.View, mode view.Mode) bool { + if item == nil { + return false + } + if mode != view.ModeQuery { + return false + } + return strings.TrimSpace(item.Table) != "" || strings.TrimSpace(item.SQL) != "" || strings.TrimSpace(item.SQLURI) != "" +} + +func shouldDeferQuerySchemaType(rType reflect.Type, mode view.Mode) bool { + if rType == nil || mode != view.ModeQuery { + return false + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() == reflect.Map || rType.Kind() == reflect.Interface { + return true + } + if rType.Kind() == reflect.Struct { + if cols := inferColumnsFromType(rType); len(cols) > 0 && inferredColumnsArePlaceholders(cols) { + return true + } + } + return false +} + +func bestSchemaType(item *plan.View) reflect.Type { + if item.FieldType != nil { + return normalizeViewSchemaReflectType(item, item.FieldType) + } + if item.ElementType != nil { + return normalizeViewSchemaReflectType(item, item.ElementType) + } + return nil +} + +func normalizeViewSchemaReflectType(item *plan.View, rType reflect.Type) reflect.Type { + if item == nil || rType == nil { + return rType + } + schemaType := strings.TrimSpace(item.SchemaType) + if !strings.HasPrefix(schemaType, "*") { + return rType + } + if strings.EqualFold(strings.TrimSpace(item.Cardinality), string(state.Many)) { + if rType.Kind() == reflect.Slice { + elem := rType.Elem() + if elem.Kind() != reflect.Ptr { + return reflect.SliceOf(reflect.PtrTo(elem)) + } + } + return rType + } + if rType.Kind() != reflect.Ptr { + return reflect.PtrTo(rType) + } + return rType +} + +func stringPtr(value string) *string { + ret := value + return &ret +} + +func boolPtr(value bool) *bool { + ret := value + return &ret +} + +func toViewRelations(input []*plan.Relation) []*view.Relation { + if len(input) == 0 { + return nil + } + result := make([]*view.Relation, 0, len(input)) + for _, item := range input { + if item == nil { + continue + } + relation := &view.Relation{ + Name: item.Name, + Holder: item.Holder, + Cardinality: state.Many, + IncludeColumn: true, + On: toViewLinks(item.On, true), + Of: view.NewReferenceView( + toViewLinks(item.On, false), + view.NewView(item.Ref, item.Table), + ), + } + result = append(result, relation) + } + return result +} + +func toViewLinks(input []*plan.RelationLink, parent bool) view.Links { + if len(input) == 0 { + return nil + } + result := make(view.Links, 0, len(input)) + for _, item := range input { + if item == nil { + continue + } + link := &view.Link{} + if parent { + link.Field = item.ParentField + link.Namespace = item.ParentNamespace + link.Column = item.ParentColumn + } else { + link.Field = item.RefField + link.Namespace = item.RefNamespace + link.Column = item.RefColumn + } + result = append(result, link) + } + return result +} + +func enrichRelationLinkFields(planned []*plan.View) { + if len(planned) == 0 { + return + } + byName := map[string]*plan.View{} + for _, item := range planned { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + byName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, item := range planned { + if item == nil || len(item.Relations) == 0 { + continue + } + for _, rel := range item.Relations { + if rel == nil || len(rel.On) == 0 { + continue + } + parentPlan := item + if parentName := strings.TrimSpace(rel.Parent); parentName != "" { + if candidate, ok := byName[strings.ToLower(parentName)]; ok && candidate != nil { + parentPlan = candidate + } + } + refPlan := byName[strings.ToLower(strings.TrimSpace(rel.Ref))] + for _, link := range rel.On { + if link == nil { + continue + } + if link.ParentField == "" { + if field := fieldNameForColumn(parentPlan, link.ParentColumn); field != "" { + link.ParentField = field + link.ParentNamespace = "" + } + } + if link.RefField == "" { + if field := fieldNameForColumn(refPlan, link.RefColumn); field != "" { + link.RefField = field + link.RefNamespace = "" + } + } + } + } + } +} + +func fieldNameForColumn(item *plan.View, column string) string { + column = strings.TrimSpace(column) + if column == "" { + return "" + } + fallback := pipeline.ExportedName(column) + if item == nil { + return fallback + } + rType := bestSchemaType(item) + if rType == nil { + return fallback + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return fallback + } + normalizedColumn := normalizeRelationColumnName(column) + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if !field.IsExported() || shouldSkipInferredField(field) { + continue + } + candidate := sqlxColumnName(field) + if candidate == "" { + candidate = field.Name + } + if normalizeRelationColumnName(candidate) == normalizedColumn { + return field.Name + } + } + return fallback +} + +func normalizeRelationColumnName(name string) string { + name = strings.TrimSpace(strings.ToLower(name)) + name = strings.ReplaceAll(name, "_", "") + return name +} + +func newSchema(rType reflect.Type, cardinality string) *state.Schema { + if rType == nil { + schema := &state.Schema{} + if cardinality == "many" { + schema.Cardinality = state.Many + } else { + schema.Cardinality = state.One + } + return schema + } + if cardinality == "many" && rType.Kind() != reflect.Slice { + return state.NewSchema(rType, state.WithMany()) + } + return state.NewSchema(rType) +} + +func attachViewRelations(resource *view.Resource, planned []*plan.View) { + if resource == nil || len(planned) == 0 { + return + } + index := resource.Views.Index() + byName := map[string]*plan.View{} + for _, item := range planned { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + byName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, item := range planned { + if item == nil || len(item.Relations) == 0 { + continue + } + candidates := toViewRelations(item.Relations) + for i, relation := range candidates { + if relation == nil || relation.Of == nil { + continue + } + parentName := relationParentName(item, item.Relations, i) + if parentName == "" { + continue + } + parent, err := index.Lookup(parentName) + if err != nil || parent == nil { + continue + } + if plannedParent, ok := byName[strings.ToLower(parentName)]; ok && plannedParent != nil { + parentName = plannedParent.Name + } + refName := strings.TrimSpace(relation.Of.View.Ref) + if refName == "" { + refName = strings.TrimSpace(relation.Of.View.Name) + } + if refName == "" { + continue + } + ref, err := index.Lookup(refName) + if err != nil || ref == nil { + continue + } + if plannedRef, ok := byName[strings.ToLower(refName)]; ok && plannedRef != nil { + if strings.EqualFold(strings.TrimSpace(plannedRef.Cardinality), string(state.One)) { + relation.Cardinality = state.One + } + } + if inferOneToOneRelation(parent, ref, relation) { + relation.Cardinality = state.One + } + relation.Of.View = cloneRelationView(ref, relation.Of.View) + parent.With = append(parent.With, relation) + } + } +} + +func cloneRelationView(ref *view.View, current view.View) view.View { + if ref == nil { + return current + } + cloned := *ref + cloned.Ref = ref.Name + cloned.Name = "" + if currentName := strings.TrimSpace(current.Name); currentName != "" && !strings.EqualFold(currentName, ref.Name) { + cloned.Name = current.Name + } + if ref.Schema != nil { + cloned.Schema = ref.Schema.Clone() + } + if ref.Template != nil { + templateCopy := *ref.Template + if ref.Template.Schema != nil { + templateCopy.Schema = ref.Template.Schema.Clone() + } + if strings.TrimSpace(templateCopy.Source) != "" { + templateCopy.SourceURL = "" + } + if ref.Template.Summary != nil { + summaryCopy := *ref.Template.Summary + if ref.Template.Summary.Schema != nil { + summaryCopy.Schema = ref.Template.Summary.Schema.Clone() + } + if strings.TrimSpace(summaryCopy.Source) != "" { + summaryCopy.SourceURL = "" + } + templateCopy.Summary = &summaryCopy + } + cloned.Template = &templateCopy + } + if cloned.Selector != nil && cloned.Selector.Limit > 0 { + if cloned.Batch == nil { + cloned.Batch = &view.Batch{} + } + if cloned.Batch.Size == 0 || cloned.Batch.Size > 1 { + cloned.Batch.Size = 1 + } + } + return cloned +} + +func hideRelationSummaryLinkFields(relation *view.Relation) { + if relation == nil || relation.Of == nil { + return + } + child := &relation.Of.View + if child.Template == nil || child.Template.Summary == nil || child.Template.Summary.Schema == nil { + return + } + hidden := map[string]bool{} + for _, link := range relation.Of.On { + if link == nil { + continue + } + if field := normalizeRelationColumnName(link.Field); field != "" { + hidden[field] = true + } + if column := normalizeRelationColumnName(link.Column); column != "" { + hidden[column] = true + } + } + if len(hidden) == 0 { + return + } + if refined := hideSummarySchemaFields(child.Template.Summary.Schema, hidden); refined != nil { + child.Template.Summary.Schema = refined + } +} + +func hideSummarySchemaFields(summarySchema *state.Schema, hidden map[string]bool) *state.Schema { + if summarySchema == nil || len(hidden) == 0 { + return nil + } + summaryType := summarySchema.Type() + if summaryType == nil { + return nil + } + isPtr := false + if summaryType.Kind() == reflect.Ptr { + isPtr = true + summaryType = summaryType.Elem() + } + if summaryType.Kind() != reflect.Struct { + return nil + } + fields := make([]reflect.StructField, 0, summaryType.NumField()) + changed := false + for i := 0; i < summaryType.NumField(); i++ { + field := summaryType.Field(i) + if shouldHideSummaryField(field, hidden) { + field.Tag = hideSummaryFieldTag(field.Tag) + changed = true + } + fields = append(fields, field) + } + if !changed { + return nil + } + refinedType := reflect.StructOf(fields) + if isPtr { + refinedType = reflect.PtrTo(refinedType) + } + refined := summarySchema.Clone() + refined.SetType(refinedType) + return refined +} + +func shouldHideSummaryField(field reflect.StructField, hidden map[string]bool) bool { + if len(hidden) == 0 { + return false + } + candidates := []string{ + normalizeRelationColumnName(field.Name), + normalizeRelationColumnName(summaryLookupName(field)), + } + for _, candidate := range candidates { + if candidate != "" && hidden[candidate] { + return true + } + } + return false +} + +var structTagPattern = regexp.MustCompile(`([A-Za-z0-9_]+):"([^"]*)"`) + +func hideSummaryFieldTag(tag reflect.StructTag) reflect.StructTag { + values := map[string]string{} + order := make([]string, 0, 4) + for _, match := range structTagPattern.FindAllStringSubmatch(string(tag), -1) { + key := strings.TrimSpace(match[1]) + if key == "" { + continue + } + if _, ok := values[key]; !ok { + order = append(order, key) + } + values[key] = match[2] + } + for _, item := range []struct { + key string + value string + }{ + {key: "internal", value: "true"}, + } { + if _, ok := values[item.key]; !ok { + order = append(order, item.key) + } + values[item.key] = item.value + } + parts := make([]string, 0, len(order)) + for _, key := range order { + if value, ok := values[key]; ok { + parts = append(parts, fmt.Sprintf(`%s:%q`, key, value)) + } + } + return reflect.StructTag(strings.Join(parts, " ")) +} + +func enrichRelationHolderTypes(resource *view.Resource, planned []*plan.View) error { + if resource == nil || len(planned) == 0 { + return nil + } + index := resource.Views.Index() + byName := map[string]*plan.View{} + for _, item := range planned { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + byName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, item := range planned { + if item == nil || len(item.Relations) == 0 { + continue + } + for i, rel := range item.Relations { + if rel == nil { + continue + } + parentName := relationParentName(item, item.Relations, i) + if parentName == "" { + parentName = item.Name + } + parent, err := index.Lookup(parentName) + if err != nil || parent == nil || parent.Schema == nil { + continue + } + parentType := parent.ComponentType() + if parentType == nil { + continue + } + augmented, changed, err := ensureRelationHolderFields(parentType, &plan.View{Relations: []*plan.Relation{rel}}, byName, index) + if err != nil { + return err + } + if !changed || augmented == nil { + continue + } + if parent.Schema.Cardinality == state.Many { + parentType := parent.Schema.Type() + if parentType != nil && parentType.Kind() == reflect.Slice { + elemType := parentType.Elem() + if elemType.Kind() == reflect.Ptr { + parent.Schema.SetType(reflect.SliceOf(reflect.PtrTo(augmented))) + continue + } + } + parent.Schema.SetType(reflect.SliceOf(augmented)) + continue + } + if parentType := parent.Schema.Type(); parentType != nil && parentType.Kind() == reflect.Ptr { + parent.Schema.SetType(reflect.PtrTo(augmented)) + continue + } + parent.Schema.SetType(augmented) + } + } + return nil +} + +func ensureRelationHolderFields(parentType reflect.Type, item *plan.View, byName map[string]*plan.View, index view.NamedViews) (reflect.Type, bool, error) { + parentType = ensureStructType(parentType) + if parentType == nil || item == nil || len(item.Relations) == 0 { + return parentType, false, nil + } + fields := make([]reflect.StructField, 0, parentType.NumField()+len(item.Relations)) + for i := 0; i < parentType.NumField(); i++ { + fields = append(fields, parentType.Field(i)) + } + changed := false + for _, rel := range item.Relations { + if rel == nil || strings.TrimSpace(rel.Holder) == "" { + continue + } + childName := strings.TrimSpace(rel.Ref) + if childName == "" { + continue + } + childView, err := index.Lookup(childName) + if err != nil || childView == nil { + continue + } + childType := childView.ComponentType() + if childType == nil && childView.Schema != nil { + childType = childView.Schema.Type() + } + if childType == nil { + if childPlanned, ok := byName[strings.ToLower(childName)]; ok && childPlanned != nil { + childType = bestSchemaType(childPlanned) + } + } + if childType == nil { + continue + } + if _, ok := parentType.FieldByName(rel.Holder); !ok && !fieldNameInSlice(fields, rel.Holder) { + fieldType := relationHolderFieldType(childType, childPlannedCardinality(childName, byName)) + if fieldType != nil { + fields = append(fields, reflect.StructField{ + Name: rel.Holder, + Type: fieldType, + Tag: reflect.StructTag(buildRelationHolderTag(rel, childView)), + }) + changed = true + } + } + if summaryField := relationSummaryField(childView); summaryField != nil { + if _, ok := parentType.FieldByName(summaryField.Name); !ok && !fieldNameInSlice(fields, summaryField.Name) { + fields = append(fields, *summaryField) + changed = true + } + } + } + if !changed { + return parentType, false, nil + } + return reflect.StructOf(fields), true, nil +} + +func relationSummaryField(childView *view.View) *reflect.StructField { + if childView == nil || childView.Template == nil || childView.Template.Summary == nil || childView.Template.Summary.Schema == nil { + return nil + } + fieldType := childView.Template.Summary.Schema.Type() + if fieldType == nil { + return nil + } + fieldName := strings.TrimSpace(childView.Template.Summary.Name) + if fieldName == "" { + return nil + } + if fieldType.Kind() == reflect.Struct { + fieldType = reflect.PtrTo(fieldType) + } + return &reflect.StructField{ + Name: fieldName, + Type: fieldType, + Tag: reflect.StructTag(`json:",omitempty" yaml:",omitempty" sqlx:"-"`), + } +} + +func fieldNameInSlice(fields []reflect.StructField, name string) bool { + for _, field := range fields { + if field.Name == name { + return true + } + } + return false +} + +func childPlannedCardinality(childName string, byName map[string]*plan.View) state.Cardinality { + if childPlanned, ok := byName[strings.ToLower(childName)]; ok && childPlanned != nil { + if strings.EqualFold(strings.TrimSpace(childPlanned.Cardinality), string(state.One)) { + return state.One + } + } + return state.Many +} + +func relationHolderFieldType(childType reflect.Type, cardinality state.Cardinality) reflect.Type { + if childType == nil { + return nil + } + childType = normalizeDeferredHolderType(childType) + if cardinality == state.One { + for childType.Kind() == reflect.Slice || childType.Kind() == reflect.Array { + childType = childType.Elem() + } + if childType.Kind() == reflect.Struct { + return reflect.PtrTo(childType) + } + return childType + } + if childType.Kind() == reflect.Slice || childType.Kind() == reflect.Array { + return childType + } + normalized := childType + if normalized.Kind() == reflect.Struct { + normalized = reflect.PtrTo(normalized) + } + return reflect.SliceOf(normalized) +} + +func normalizeDeferredHolderType(rType reflect.Type) reflect.Type { + if rType == nil { + return nil + } + kind := rType.Kind() + if kind == reflect.Slice || kind == reflect.Array { + elem := rType.Elem() + for elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + if elem.Kind() == reflect.Map || elem.Kind() == reflect.Interface { + return reflect.SliceOf(reflect.TypeOf(struct{}{})) + } + return rType + } + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType.Kind() == reflect.Map || rType.Kind() == reflect.Interface { + return reflect.TypeOf(struct{}{}) + } + return rType +} + +func ensureStructType(rType reflect.Type) reflect.Type { + if rType == nil { + return nil + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + return rType +} + +func buildRelationHolderTag(rel *plan.Relation, child *view.View) string { + if rel == nil { + return `json:",omitempty" sqlx:"-"` + } + table := "" + sqlExpr := "" + if child != nil { + table = strings.TrimSpace(child.Table) + if child.Template != nil { + if uri := strings.TrimSpace(child.Template.SourceURL); uri != "" { + sqlExpr = "uri=" + uri + } else if source := strings.TrimSpace(child.Template.Source); source != "" { + sqlExpr = source + } + } + } + tagParts := []string{fmt.Sprintf(`view:",table=%s"`, table)} + if onExpr := buildRelationOnTag(rel); onExpr != "" { + tagParts = append(tagParts, fmt.Sprintf(`on:"%s"`, onExpr)) + } + if sqlExpr != "" { + tagParts = append(tagParts, fmt.Sprintf(`sql:%q`, sqlExpr)) + } + tagParts = append(tagParts, `json:",omitempty"`, `sqlx:"-"`) + return strings.Join(tagParts, " ") +} + +func buildRelationOnTag(rel *plan.Relation) string { + if rel == nil || len(rel.On) == 0 { + return "" + } + parts := make([]string, 0, len(rel.On)) + for _, link := range rel.On { + if link == nil { + continue + } + parentField := firstNonEmpty(strings.TrimSpace(link.ParentField), strings.TrimSpace(link.ParentColumn)) + refField := firstNonEmpty(strings.TrimSpace(link.RefField), strings.TrimSpace(link.RefColumn)) + if parentField == "" || refField == "" { + continue + } + parts = append(parts, fmt.Sprintf("%s:%s=%s:%s", parentField, link.ParentColumn, refField, link.RefColumn)) + } + return strings.Join(parts, ",") +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} + +func bindTemplateParameters(resource *view.Resource) { + if resource == nil || len(resource.Parameters) == 0 { + return + } + params := make([]*state.Parameter, 0, len(resource.Parameters)) + for _, param := range resource.Parameters { + if param == nil || param.In == nil { + continue + } + switch param.In.Kind { + case state.KindOutput, state.KindMeta, state.KindAsync: + continue + } + params = append(params, param) + } + if len(params) == 0 { + return + } + for _, item := range resource.Views { + bindViewTemplateParameters(item, params) + } +} + +func bindViewTemplateParameters(aView *view.View, params []*state.Parameter) { + if aView == nil { + return + } + if aView.Template != nil { + if aView.Template.DeclaredParametersOnly { + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + bindViewTemplateParameters(&rel.Of.View, params) + } + return + } + seen := map[string]bool{} + for _, item := range aView.Template.Parameters { + if item != nil { + seen[strings.ToLower(strings.TrimSpace(item.Name))] = true + } + } + for _, param := range params { + if param == nil || strings.TrimSpace(param.Name) == "" { + continue + } + if param.In != nil && param.In.Kind == state.KindView && strings.EqualFold(strings.TrimSpace(param.In.Name), strings.TrimSpace(aView.Name)) { + continue + } + key := strings.ToLower(strings.TrimSpace(param.Name)) + if seen[key] { + continue + } + aView.Template.Parameters = append(aView.Template.Parameters, param) + seen[key] = true + } + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + bindViewTemplateParameters(&rel.Of.View, params) + } +} + +func inferOneToOneRelation(parent, ref *view.View, relation *view.Relation) bool { + if parent == nil || ref == nil || relation == nil || relation.Of == nil { + return false + } + parentTable := strings.TrimSpace(parent.Table) + refTable := strings.TrimSpace(ref.Table) + if parentTable == "" || refTable == "" || !strings.EqualFold(parentTable, refTable) { + return false + } + if len(relation.On) == 0 || len(relation.Of.On) == 0 { + return false + } + count := len(relation.On) + if len(relation.Of.On) < count { + count = len(relation.Of.On) + } + if count == 0 { + return false + } + for i := 0; i < count; i++ { + parentCol := normalizeRelationColumn(relation.On[i].Column) + refCol := normalizeRelationColumn(relation.Of.On[i].Column) + if parentCol == "" || refCol == "" || !strings.EqualFold(parentCol, refCol) { + return false + } + } + return true +} + +func normalizeRelationColumn(column string) string { + column = strings.TrimSpace(column) + if column == "" { + return "" + } + if idx := strings.LastIndex(column, "."); idx != -1 && idx+1 < len(column) { + column = column[idx+1:] + } + return strings.TrimSpace(column) +} + +func relationParentName(source *plan.View, relations []*plan.Relation, index int) string { + if index >= 0 && index < len(relations) { + item := relations[index] + if item != nil { + if parent := strings.TrimSpace(item.Parent); parent != "" { + return parent + } + for _, link := range item.On { + if link == nil { + continue + } + if parent := strings.TrimSpace(link.ParentNamespace); parent != "" { + return parent + } + } + } + } + if source == nil { + return "" + } + return strings.TrimSpace(source.Name) +} + +func cloneStateParameter(item *plan.State) *state.Parameter { + if item == nil { + return nil + } + param := item.Parameter + if param.In != nil { + in := *param.In + param.In = &in + } + if param.Schema != nil { + schema := *param.Schema + param.Schema = &schema + } + if len(param.Predicates) > 0 { + preds := make([]*extension.PredicateConfig, 0, len(param.Predicates)) + for _, candidate := range param.Predicates { + if candidate == nil { + continue + } + pred := *candidate + if len(candidate.Args) > 0 { + pred.Args = append([]string{}, candidate.Args...) + } + preds = append(preds, &pred) + } + param.Predicates = preds + } + return ¶m +} + +func clonePlanState(item *plan.State) *plan.State { + if item == nil { + return nil + } + cloned := *item + if param := cloneStateParameter(item); param != nil { + cloned.Parameter = *param + } + return &cloned +} + +func normalizeDerivedInputSchema(param *state.Parameter, resource *view.Resource) { + if param == nil || param.In == nil || resource == nil { + return + } + if param.In.Kind != state.KindView { + return + } + viewName := strings.TrimSpace(param.Name) + if name := strings.TrimSpace(param.In.Name); name != "" { + viewName = name + } + aView, _ := resource.View(viewName) + if aView == nil || aView.Schema == nil { + return + } + required := param.Required != nil && *param.Required + if param.Schema == nil { + param.Schema = aView.Schema.Clone() + if required && param.Schema != nil { + param.Schema.Cardinality = state.One + } + return + } + if strings.TrimSpace(param.Schema.Name) == "" { + param.Schema.Name = strings.TrimSpace(aView.Schema.Name) + } + dataType := strings.TrimSpace(param.Schema.DataType) + if dataType == "" || dataType == "?" || dataType == "interface{}" || dataType == "[]interface{}" || dataType == "*interface{}" || dataType == "string" || dataType == "[]string" { + param.Schema.DataType = strings.TrimSpace(aView.Schema.DataType) + if param.Schema.DataType == "" && param.Schema.Name != "" { + param.Schema.DataType = "*" + param.Schema.Name + } + } + if strings.TrimSpace(param.Schema.Package) == "" { + param.Schema.Package = strings.TrimSpace(aView.Schema.Package) + } + if param.Schema.Type() == nil && aView.Schema.Type() != nil { + param.Schema.SetType(aView.Schema.Type()) + } + if resourceUsesVelty(resource) { + if rebuilt := ensureSchemaTypeVeltyAliases(param.Schema.Type()); rebuilt != nil { + param.Schema.SetType(rebuilt) + if aView.Schema != nil && schemaNeedsVeltyAliases(aView.Schema.Type()) { + aView.Schema.SetType(rebuilt) + } + } + } + if param.Schema.Cardinality == "" { + if required { + param.Schema.Cardinality = state.One + } else if aView.Schema.Cardinality != "" { + param.Schema.Cardinality = aView.Schema.Cardinality + } + } +} + +func resourceUsesVelty(resource *view.Resource) bool { + if resource == nil { + return false + } + for _, aView := range resource.Views { + if aView != nil && aView.Mode == view.ModeExec { + return true + } + } + return false +} + +func rootResourceView(resource *view.Resource, planned []*plan.View) *view.View { + if resource == nil { + return nil + } + rootPlan := pickRootView(planned) + if rootPlan == nil || strings.TrimSpace(rootPlan.Name) == "" { + if len(resource.Views) > 0 { + return resource.Views[0] + } + return nil + } + index := resource.Views.Index() + root, _ := index.Lookup(rootPlan.Name) + return root +} + +func inheritRootOutputSchema(param *state.Parameter, root *view.View) { + if param == nil || param.In == nil || root == nil || root.Schema == nil { + return + } + if param.In.Kind != state.KindOutput || !strings.EqualFold(strings.TrimSpace(param.In.Name), "view") { + return + } + dataType := "" + if param.Schema != nil { + dataType = strings.TrimSpace(param.Schema.DataType) + } + if dataType != "" && dataType != "?" { + return + } + if param.Schema == nil { + param.Schema = &state.Schema{} + } + explicit := *param.Schema + schema := *root.Schema + if strings.TrimSpace(explicit.Name) != "" { + schema.Name = strings.TrimSpace(explicit.Name) + } + if dataType := strings.TrimSpace(explicit.DataType); dataType != "" && dataType != "?" { + schema.DataType = dataType + } + if pkg := strings.TrimSpace(explicit.Package); pkg != "" { + schema.Package = pkg + } + if pkgPath := strings.TrimSpace(explicit.PackagePath); pkgPath != "" { + schema.PackagePath = pkgPath + } + if modulePath := strings.TrimSpace(explicit.ModulePath); modulePath != "" { + schema.ModulePath = modulePath + } + if explicit.Cardinality != "" { + schema.Cardinality = explicit.Cardinality + } + if schema.Cardinality == state.One && schema.Type() != nil { + if normalized := collapseSchemaTypeToOne(schema.Type()); normalized != nil { + schema.SetType(normalized) + if strings.TrimSpace(schema.DataType) == "" || strings.HasPrefix(strings.TrimSpace(schema.DataType), "[]") { + schema.DataType = normalized.String() + } + } + } + param.Schema = &schema +} + +func collapseSchemaTypeToOne(rType reflect.Type) reflect.Type { + if rType == nil { + return nil + } + switch rType.Kind() { + case reflect.Slice: + return rType.Elem() + case reflect.Ptr: + elem := collapseSchemaTypeToOne(rType.Elem()) + if elem == nil { + return nil + } + if elem.Kind() == reflect.Slice { + return elem + } + return reflect.PtrTo(elem) + default: + return rType + } +} + +func inheritRootBodySchema(param *state.Parameter, root *view.View) { + if param == nil || param.In == nil || root == nil || root.Schema == nil { + return + } + if param.In.Kind != state.KindRequestBody { + return + } + if !param.IsAnonymous() { + return + } + if param.Schema == nil { + param.Schema = &state.Schema{} + } + explicit := *param.Schema + schema := *root.Schema + if strings.TrimSpace(explicit.Name) != "" { + schema.Name = strings.TrimSpace(explicit.Name) + } + if dataType := strings.TrimSpace(explicit.DataType); dataType != "" && dataType != "?" { + schema.DataType = dataType + } + if pkg := strings.TrimSpace(explicit.Package); pkg != "" { + schema.Package = pkg + } + if pkgPath := strings.TrimSpace(explicit.PackagePath); pkgPath != "" { + schema.PackagePath = pkgPath + } + if modulePath := strings.TrimSpace(explicit.ModulePath); modulePath != "" { + schema.ModulePath = modulePath + } + if explicit.Cardinality != "" { + schema.Cardinality = explicit.Cardinality + } + if schema.Type() == nil && root.Schema.Type() != nil { + schema.SetType(root.Schema.Type()) + } + if schema.Cardinality == state.One && schema.Type() != nil { + if normalized := collapseSchemaTypeToOne(schema.Type()); normalized != nil { + schema.SetType(normalized) + if strings.TrimSpace(schema.DataType) == "" || strings.HasPrefix(strings.TrimSpace(schema.DataType), "[]") { + schema.DataType = normalized.String() + } + } + } + param.Schema = &schema +} + +func ensureMaterializedOutputSchema(param *state.Parameter, root *view.View, source *shape.Source, ctx *typectx.Context) { + if param == nil || param.In == nil { + return + } + if param.In.Kind != state.KindOutput { + return + } + switch strings.ToLower(strings.TrimSpace(param.In.Name)) { + case "status": + if param.Schema != nil && (param.Schema.Type() != nil || strings.TrimSpace(param.Schema.DataType) != "") { + return + } + param.Schema = state.NewSchema(reflect.TypeOf(response.Status{})) + case "summary": + if (param.Schema == nil || (param.Schema.Type() == nil && strings.TrimSpace(param.Schema.DataType) == "" || strings.TrimSpace(param.Schema.DataType) == "?")) && root != nil && root.Template != nil && root.Template.Summary != nil && root.Template.Summary.Schema != nil { + param.Schema = root.Template.Summary.Schema.Clone() + } + if (param.Schema == nil || (param.Schema.Type() == nil && (strings.TrimSpace(param.Schema.DataType) == "" || strings.TrimSpace(param.Schema.DataType) == "?"))) && strings.TrimSpace(param.Name) != "" { + if summaryType := resolveSummarySchemaType(source, ctx, param.Name); summaryType != nil { + param.Schema = materializedSummarySchema(summaryType, param.Name, ctx) + } + } } - return state.NewSchema(rType) } diff --git a/repository/shape/load/loader_contract_state_test.go b/repository/shape/load/loader_contract_state_test.go new file mode 100644 index 000000000..009cd3e10 --- /dev/null +++ b/repository/shape/load/loader_contract_state_test.go @@ -0,0 +1,23 @@ +package load + +import ( + "reflect" + "testing" +) + +func TestContractStates_PreservesCodecAndHandler(t *testing.T) { + type input struct { + Jwt string `parameter:",kind=header,in=Authorization,errorCode=401" codec:"JwtClaim"` + Run string `parameter:",kind=body,in=run" handler:"Exec"` + } + states := contractStates(reflect.TypeOf(input{})) + if got, want := len(states), 2; got != want { + t.Fatalf("expected %d states, got %d", want, got) + } + if states[0].Output == nil || states[0].Output.Name != "JwtClaim" { + t.Fatalf("expected codec to be preserved, got %#v", states[0].Output) + } + if states[1].Handler == nil || states[1].Handler.Name != "Exec" { + t.Fatalf("expected handler to be preserved, got %#v", states[1].Handler) + } +} diff --git a/repository/shape/load/loader_dql_test.go b/repository/shape/load/loader_dql_test.go new file mode 100644 index 000000000..146420884 --- /dev/null +++ b/repository/shape/load/loader_dql_test.go @@ -0,0 +1,108 @@ +package load_test + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + shapePlan "github.com/viant/datly/repository/shape/plan" +) + +func TestLoadComponent_DQLUserMetadataPreservesBitColumns(t *testing.T) { + dqlPath := filepath.Join("..", "..", "..", "e2e", "v1", "dql", "dev", "user", "user_metadata.dql") + dqlPath, err := filepath.Abs(dqlPath) + require.NoError(t, err) + data, err := os.ReadFile(dqlPath) + require.NoError(t, err) + + source := &shape.Source{ + Name: "user_metadata", + Path: dqlPath, + DQL: string(data), + } + planned, err := shapeCompile.New().Compile(context.Background(), source) + require.NoError(t, err) + actualPlan, ok := shapePlan.ResultFrom(planned) + require.True(t, ok) + require.NotNil(t, actualPlan.TypeContext) + assert.Equal(t, "github.com/viant/datly/e2e/v1/shape/dev/user/mysql_boolean", actualPlan.TypeContext.PackagePath) + t.Logf("typectx: dir=%q name=%q path=%q", actualPlan.TypeContext.PackageDir, actualPlan.TypeContext.PackageName, actualPlan.TypeContext.PackagePath) + + artifact, err := shapeLoad.New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + root, err := artifact.Resource.Views.Index().Lookup("user_metadata") + require.NoError(t, err) + require.NotNil(t, root) + require.NotNil(t, root.Schema) + require.NotNil(t, root.Schema.Type()) + t.Logf("schema type: %v", root.Schema.Type()) + + names := make([]string, 0, len(root.Columns)) + for _, column := range root.Columns { + if column == nil { + continue + } + names = append(names, column.Name) + } + assert.Contains(t, names, "IS_ENABLED") + assert.Contains(t, names, "IS_ACTIVATED") +} + +func TestLoadComponent_DQLVarsHonorsDeclaredColumnType(t *testing.T) { + dqlPath := filepath.Join("..", "..", "..", "e2e", "v1", "dql", "dev", "vendorsrv", "vars.dql") + dqlPath, err := filepath.Abs(dqlPath) + require.NoError(t, err) + data, err := os.ReadFile(dqlPath) + require.NoError(t, err) + + source := &shape.Source{ + Name: "vars", + Path: dqlPath, + DQL: string(data), + } + planned, err := shapeCompile.New().Compile(context.Background(), source) + require.NoError(t, err) + actualPlan, ok := shapePlan.ResultFrom(planned) + require.True(t, ok) + for _, item := range actualPlan.Views { + if item == nil || item.Name != "main" { + continue + } + if item.Declaration != nil && item.Declaration.ColumnsConfig != nil { + if cfg := item.Declaration.ColumnsConfig["Key3"]; cfg != nil { + t.Logf("planned Key3 dataType=%q", cfg.DataType) + } + } + } + + artifact, err := shapeLoad.New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + root, err := artifact.Resource.Views.Index().Lookup("main") + require.NoError(t, err) + require.NotNil(t, root) + require.NotNil(t, root.Schema) + require.NotNil(t, root.Schema.Type()) + if cfg := root.ColumnsConfig["Key3"]; cfg != nil && cfg.DataType != nil { + t.Logf("Key3 config dataType=%q", *cfg.DataType) + } + t.Logf("vars schema type: %v", root.Schema.Type()) + + var key3Type string + for _, column := range root.Columns { + if column == nil || column.Name != "Key3" { + continue + } + if column.ColumnType() != nil { + key3Type = column.ColumnType().String() + } else { + key3Type = column.DataType + } + } + assert.Equal(t, "bool", key3Type) +} diff --git a/repository/shape/load/loader_test.go b/repository/shape/load/loader_test.go index 20117b4e9..cfdc2e2ac 100644 --- a/repository/shape/load/loader_test.go +++ b/repository/shape/load/loader_test.go @@ -3,7 +3,10 @@ package load import ( "context" "embed" + "os" + "path/filepath" "reflect" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -13,6 +16,11 @@ import ( "github.com/viant/datly/repository/shape/plan" "github.com/viant/datly/repository/shape/scan" "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/x" + "github.com/viant/xdatly" + "github.com/viant/xdatly/handler/response" ) //go:embed testdata/*.sql @@ -29,6 +37,65 @@ type reportRow struct { Name string } +type relationTreeRoot struct { + ID int `sqlx:"name=ID"` +} + +type relationTreeChild struct { + ID int `sqlx:"name=ID"` + RootID int `sqlx:"name=RootID"` +} + +type relationTreeGrandChild struct { + ID int `sqlx:"name=ID"` + ChildID int `sqlx:"name=ChildID"` +} + +type cityRow struct { + ID int `sqlx:"name=ID"` + DistrictID int `sqlx:"name=DISTRICT_ID"` +} + +type vendorProductRow struct { + ID int `sqlx:"name=ID"` + VendorID int `sqlx:"name=VENDOR_ID"` +} + +type metaSummaryRow struct { + PageCnt int `sqlx:"name=PAGE_CNT"` + Cnt int `sqlx:"name=CNT"` +} + +type productsMetaSummaryRow struct { + VendorID int `sqlx:"name=VENDOR_ID"` + PageCnt int `sqlx:"name=PAGE_CNT"` + TotalProducts int `sqlx:"name=TOTAL_PRODUCTS"` +} + +type productsOwnerPointerRow struct { + VendorID *int `sqlx:"name=VENDOR_ID"` +} + +type vendorSummaryParentRow struct { + ID int `sqlx:"name=ID"` +} + +type vendorSummaryChildRow struct { + ID int `sqlx:"name=ID"` + VendorID int `sqlx:"name=VENDOR_ID"` +} + +type fieldOnlyUserACLRow struct { + UserID int `sqlx:"name=UserID"` + IsReadOnly int `sqlx:"name=IsReadOnly"` + Feature1 int `sqlx:"name=Feature1"` +} + +type placeholderDistrictRow struct { + Col1 string `sqlx:"name=col_1"` + Col2 string `sqlx:"name=col_2"` +} + type reportSource struct { embeddedFS Rows []reportRow `view:"rows,table=REPORT,connector=dev,cache=c1" sql:"uri=testdata/report.sql"` @@ -36,6 +103,80 @@ type reportSource struct { Status any `parameter:"status,kind=output,in=status"` Job any `parameter:"job,kind=async,in=job"` Meta any `parameter:"meta,kind=meta,in=view.name"` + Route struct{} `component:",path=/v1/api/dev/report,method=GET,connector=dev"` +} + +type typedRouteInput struct { + ID int +} + +type typedRouteOutput struct { + Data []reportRow +} + +type typedTeamRouteInput struct { + TeamID string `parameter:",kind=path,in=teamID"` +} + +type typedTeamRouteOutput struct{} + +type typedRouteSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT" sql:"uri=testdata/report.sql"` + Route xdatly.Component[typedRouteInput, typedRouteOutput] `component:",path=/v1/api/dev/report,method=GET"` +} + +type reportEnabledLoadSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT" sql:"uri=testdata/report.sql"` + Route xdatly.Component[typedRouteInput, typedRouteOutput] `component:",path=/v1/api/dev/report,method=GET,report=true,reportInput=NamedReportInput,reportDimensions=Dims,reportMeasures=Metrics,reportFilters=Predicates,reportOrderBy=Sort,reportLimit=Take,reportOffset=Skip"` +} + +type dynamicRouteInput struct { + Name string +} + +type dynamicRouteOutput struct { + Count int +} + +type namedDynamicRouteInput struct { + Name string `parameter:"name,kind=query,in=name"` +} + +type namedDynamicRouteOutput struct { + response.Status `parameter:",kind=output,in=status" json:",omitempty"` + Data []*reportRow `parameter:",kind=output,in=view" view:"rows,table=REPORT" sql:"uri=testdata/report.sql" anonymous:"true"` +} + +type dynamicRouteSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT" sql:"uri=testdata/report.sql"` + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET"` +} + +type selectorHolderSource struct { + Rows []reportRow `view:"rows,table=REPORT" sql:"uri=testdata/report.sql"` + Route xdatly.Component[typedRouteInput, typedRouteOutput] `component:",path=/v1/api/dev/report,method=GET"` + ViewSelect struct { + Fields []string `parameter:"fields,kind=query,in=_fields,cacheable=false"` + Page int `parameter:"page,kind=query,in=_page,cacheable=false"` + } `querySelector:"rows"` +} + +type routerOnlyInput struct { + ID int `parameter:"id,kind=query,in=id"` +} + +func (*routerOnlyInput) EmbedFS() *embed.FS { return &testFS } + +type routerOnlyOutput struct { + response.Status `parameter:",kind=output,in=status" json:",omitempty"` + Data []*reportRow `parameter:",kind=output,in=view" view:"rows,table=REPORT" sql:"uri=testdata/report.sql" anonymous:"true"` +} + +type routerOnlySource struct { + Route xdatly.Component[routerOnlyInput, routerOnlyOutput] `component:",path=/v1/api/dev/router-only,method=GET"` } func TestLoader_LoadViews(t *testing.T) { @@ -69,6 +210,60 @@ func TestLoader_LoadViews(t *testing.T) { require.NotNil(t, artifacts.Resource.EmbedFS()) } +func TestLoader_LoadComponent_PreservesReportConfig(t *testing.T) { + scanned, err := scan.New().Scan(context.Background(), &shape.Source{Struct: &reportEnabledLoadSource{}}) + require.NoError(t, err) + + planned, err := plan.New().Plan(context.Background(), scanned) + require.NoError(t, err) + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component.Report) + assert.True(t, component.Report.Enabled) + assert.Equal(t, "NamedReportInput", component.Report.Input) + assert.Equal(t, "Dims", component.Report.Dimensions) + assert.Equal(t, "Metrics", component.Report.Measures) + assert.Equal(t, "Predicates", component.Report.Filters) + assert.Equal(t, "Sort", component.Report.OrderBy) + assert.Equal(t, "Take", component.Report.Limit) + assert.Equal(t, "Skip", component.Report.Offset) +} + +func TestLoader_LoadResource(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "report"}, + Plan: &plan.Result{ + EmbedFS: &testFS, + Views: []*plan.View{ + { + Name: "rows", + Table: "REPORT", + Connector: "dev", + SQL: "SELECT ID, NAME FROM REPORT", + SQLURI: "testdata/report.sql", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }, + }, + ViewsByName: map[string]*plan.View{"rows": {Name: "rows"}}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadResource(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.NotNil(t, artifacts.Resource) + require.Len(t, artifacts.Resource.Views, 1) + assert.Equal(t, "rows", artifacts.Resource.Views[0].Name) + assert.Equal(t, "REPORT", artifacts.Resource.Views[0].Table) +} + // stubPlanSpec is a non-plan-Result implementation of shape.PlanSpec used to // verify that LoadViews() returns an error when given an unexpected plan type. type stubPlanSpec struct{} @@ -85,22 +280,45 @@ func TestLoader_LoadViews_InvalidPlanType(t *testing.T) { func TestLoader_LoadViews_Metadata(t *testing.T) { noLimit := true allowNulls := true + groupable := true + criteria := true + projection := true + orderBy := true + offset := true planned := &shape.PlanResult{ Source: &shape.Source{Name: "meta"}, Plan: &plan.Result{ Views: []*plan.View{ { - Name: "items", - Table: "ITEMS", - Module: "platform/items", - AllowNulls: &allowNulls, - SelectorNamespace: "it", - SelectorNoLimit: &noLimit, - SchemaType: "*ItemView", - Cardinality: "many", - FieldType: reflect.TypeOf([]map[string]interface{}{}), - ElementType: reflect.TypeOf(map[string]interface{}{}), - SQL: "SELECT * FROM ITEMS", + Name: "items", + Table: "ITEMS", + Module: "platform/items", + AllowNulls: &allowNulls, + Groupable: &groupable, + SelectorNamespace: "it", + SelectorNoLimit: &noLimit, + SelectorCriteria: &criteria, + SelectorProjection: &projection, + SelectorOrderBy: &orderBy, + SelectorOffset: &offset, + SelectorFilterable: []string{"*"}, + SelectorOrderByColumns: map[string]string{ + "accountId": "ACCOUNT_ID", + }, + SchemaType: "*ItemView", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM ITEMS", + Declaration: &plan.ViewDeclaration{ + ColumnsConfig: map[string]*plan.ViewColumnConfig{ + "AUTHORIZED": { + DataType: "bool", + Tag: `internal:"true"`, + Groupable: &groupable, + }, + }, + }, }, }, ViewsByName: map[string]*plan.View{}, @@ -116,11 +334,179 @@ func TestLoader_LoadViews_Metadata(t *testing.T) { assert.Equal(t, "platform/items", actual.Module) require.NotNil(t, actual.AllowNulls) assert.True(t, *actual.AllowNulls) + assert.True(t, actual.Groupable) require.NotNil(t, actual.Selector) assert.Equal(t, "it", actual.Selector.Namespace) assert.True(t, actual.Selector.NoLimit) + require.NotNil(t, actual.Selector.Constraints) + assert.True(t, actual.Selector.Constraints.Limit) + assert.True(t, actual.Selector.Constraints.Criteria) + assert.True(t, actual.Selector.Constraints.Projection) + assert.True(t, actual.Selector.Constraints.OrderBy) + assert.True(t, actual.Selector.Constraints.Offset) + assert.Equal(t, []string{"*"}, actual.Selector.Constraints.Filterable) + assert.Equal(t, "ACCOUNT_ID", actual.Selector.Constraints.OrderByColumn["accountId"]) require.NotNil(t, actual.Schema) assert.Equal(t, "*ItemView", actual.Schema.DataType) + require.NotNil(t, actual.ColumnsConfig) + require.Contains(t, actual.ColumnsConfig, "AUTHORIZED") + require.NotNil(t, actual.ColumnsConfig["AUTHORIZED"].DataType) + assert.Equal(t, "bool", *actual.ColumnsConfig["AUTHORIZED"].DataType) + require.NotNil(t, actual.ColumnsConfig["AUTHORIZED"].Tag) + assert.Equal(t, `internal:"true"`, *actual.ColumnsConfig["AUTHORIZED"].Tag) + require.NotNil(t, actual.ColumnsConfig["AUTHORIZED"].Groupable) + assert.True(t, *actual.ColumnsConfig["AUTHORIZED"].Groupable) +} + +func TestLoader_LoadViews_SelectorLimitEnablesConstraint(t *testing.T) { + limit := 2 + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "district"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "cities", + Table: "CITY", + SelectorLimit: &limit, + SelectorNamespace: "ci", + SchemaType: "*CitiesView", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM CITY", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + loader := New() + artifacts, err := loader.LoadViews(context.Background(), planned) + require.NoError(t, err) + require.Len(t, artifacts.Views, 1) + actual := artifacts.Views[0] + require.NotNil(t, actual.Selector) + require.NotNil(t, actual.Selector.Constraints) + assert.Equal(t, 2, actual.Selector.Limit) + assert.True(t, actual.Selector.Constraints.Limit) +} + +func TestCloneRelationView_SelectorLimitUsesSingleParentBatch(t *testing.T) { + ref, err := view.New("cities", "CITY") + require.NoError(t, err) + ref.Selector = &view.Config{ + Limit: 2, + Constraints: &view.Constraints{ + Limit: true, + }, + } + cloned := cloneRelationView(ref, view.View{}) + require.NotNil(t, cloned.Batch) + assert.Equal(t, 1, cloned.Batch.Size) +} + +func TestLoader_LoadViews_InfersColumnsFromBestSchemaType(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "user_acl"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "user_acl", + SchemaType: "*UserAclView", + FieldType: reflect.TypeOf([]fieldOnlyUserACLRow{}), + ElementType: nil, + SQL: "SELECT 1", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + loader := New() + artifacts, err := loader.LoadViews(context.Background(), planned) + require.NoError(t, err) + require.Len(t, artifacts.Views, 1) + require.Len(t, artifacts.Views[0].Columns, 3) + assert.Equal(t, "UserID", artifacts.Views[0].Columns[0].Name) + assert.Equal(t, "IsReadOnly", artifacts.Views[0].Columns[1].Name) + assert.Equal(t, "Feature1", artifacts.Views[0].Columns[2].Name) +} + +func TestBindTemplateParameters_SkipsSelfViewParameter(t *testing.T) { + resource := &view.Resource{ + Parameters: state.Parameters{ + {Name: "Jwt", In: state.NewHeaderLocation("Authorization")}, + {Name: "VendorID", In: state.NewPathLocation("vendorID")}, + {Name: "Authorization", In: state.NewViewLocation("authorization")}, + {Name: "Auth", In: state.NewComponent("GET:/auth")}, + }, + Views: []*view.View{ + { + Name: "authorization", + Template: view.NewTemplate("SELECT Authorized", view.WithTemplateParameters( + &state.Parameter{Name: "Jwt", In: state.NewHeaderLocation("Authorization")}, + )), + }, + }, + } + + bindTemplateParameters(resource) + + require.Len(t, resource.Views, 1) + require.NotNil(t, resource.Views[0].Template) + var names []string + for _, param := range resource.Views[0].Template.Parameters { + names = append(names, param.Name) + } + assert.ElementsMatch(t, []string{"Jwt", "VendorID", "Auth"}, names) +} + +func TestLoader_LoadComponent_NormalizesViewInputSchemaFromResourceView(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/teams"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "user_team", + Table: "TEAM", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "UPDATE TEAM SET ACTIVE = false", + }, + { + Name: "TeamStats", + Table: "TEAM", + Cardinality: "many", + SchemaType: "*TeamStatsView", + FieldType: reflect.TypeOf([]struct { + ID int `sqlx:"name=ID"` + TeamMembers int `sqlx:"name=TEAM_MEMBERS"` + Name string `sqlx:"name=NAME"` + }{}), + SQL: "SELECT ID, 0 AS TEAM_MEMBERS, NAME FROM TEAM", + }, + }, + States: []*plan.State{ + {Parameter: state.Parameter{Name: "TeamIDs", In: state.NewQueryLocation("TeamIDs"), Schema: &state.Schema{DataType: "[]int"}}}, + {Parameter: state.Parameter{Name: "TeamStats", In: state.NewViewLocation("TeamStats")}}, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + param := component.InputParameters().Lookup("TeamStats") + require.NotNil(t, param) + require.NotNil(t, param.Schema) + assert.Equal(t, "TeamStatsView", param.Schema.Name) + assert.Equal(t, "*TeamStatsView", param.Schema.DataType) + assert.Equal(t, state.Many, param.Schema.Cardinality) } func TestLoader_LoadComponent(t *testing.T) { @@ -143,9 +529,20 @@ func TestLoader_LoadComponent(t *testing.T) { actualPlan.Directives = &dqlshape.Directives{ Meta: "docs/report.md", DefaultConnector: "analytics", + TemplateType: "patch", + Dest: "all.go", + InputDest: "input.go", + OutputDest: "output.go", + RouterDest: "router.go", + InputType: "CustomInput", + OutputType: "CustomOutput", Cache: &dqlshape.CacheDirective{ - Enabled: true, - TTL: "5m", + Enabled: true, + TTL: "5m", + Name: "aerospike", + Provider: "aerospike://127.0.0.1:3000/test", + Location: "${view.Name}", + TimeToLiveMs: 3600000, }, MCP: &dqlshape.MCPDirective{ Name: "report.list", @@ -155,7 +552,7 @@ func TestLoader_LoadComponent(t *testing.T) { } loader := New() - artifact, err := loader.LoadComponent(context.Background(), planned) + artifact, err := loader.LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) require.NoError(t, err) require.NotNil(t, artifact) require.NotNil(t, artifact.Resource) @@ -164,8 +561,12 @@ func TestLoader_LoadComponent(t *testing.T) { component, ok := ComponentFrom(artifact) require.True(t, ok) assert.Equal(t, "/v1/api/report", component.Name) - assert.Equal(t, "/v1/api/report", component.URI) + assert.Equal(t, "/v1/api/dev/report", component.URI) assert.Equal(t, "GET", component.Method) + require.Len(t, component.ComponentRoutes, 1) + assert.Equal(t, "Route", component.ComponentRoutes[0].FieldName) + assert.Equal(t, "/v1/api/dev/report", component.ComponentRoutes[0].RoutePath) + assert.Equal(t, "dev", component.ComponentRoutes[0].Connector) assert.Equal(t, "rows", component.RootView) assert.Equal(t, []string{"rows"}, component.Views) assert.Len(t, component.Input, 1) @@ -179,48 +580,92 @@ func TestLoader_LoadComponent(t *testing.T) { require.NotNil(t, component.Directives) assert.Equal(t, "docs/report.md", component.Directives.Meta) assert.Equal(t, "analytics", component.Directives.DefaultConnector) + assert.Equal(t, "patch", component.Directives.TemplateType) require.NotNil(t, component.Directives.Cache) assert.True(t, component.Directives.Cache.Enabled) assert.Equal(t, "5m", component.Directives.Cache.TTL) + assert.Equal(t, "aerospike", component.Directives.Cache.Name) require.NotNil(t, component.Directives.MCP) assert.Equal(t, "report.list", component.Directives.MCP.Name) + require.NotEmpty(t, artifact.Resource.CacheProviders) + assert.Equal(t, "aerospike", artifact.Resource.CacheProviders[0].Name) + assert.Equal(t, "aerospike://127.0.0.1:3000/test", artifact.Resource.CacheProviders[0].Provider) + assert.Equal(t, "${view.Name}", artifact.Resource.CacheProviders[0].Location) + assert.Equal(t, 3600000, artifact.Resource.CacheProviders[0].TimeToLiveMs) assert.True(t, component.ColumnsDiscovery) + require.NotNil(t, component.TypeSpecs) + require.NotNil(t, component.TypeSpecs["input"]) + assert.Equal(t, "CustomInput", component.TypeSpecs["input"].TypeName) + assert.Equal(t, "input.go", component.TypeSpecs["input"].Dest) + require.NotNil(t, component.TypeSpecs["output"]) + assert.Equal(t, "CustomOutput", component.TypeSpecs["output"].TypeName) + assert.Equal(t, "output.go", component.TypeSpecs["output"].Dest) + assert.Equal(t, "router.go", component.Directives.RouterDest) } -func TestLoader_LoadComponent_RelationFieldsPreserved(t *testing.T) { +func TestLoader_LoadComponent_UsesComponentRouteWhenSourceNameMissing(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ComponentRoutes, 1) + assert.Equal(t, "/v1/api/dev/report", component.URI) + assert.Equal(t, "/v1/api/dev/report", component.Name) + assert.Equal(t, "GET", component.Method) + assert.Equal(t, "Route", component.ComponentRoutes[0].FieldName) +} + +func TestLoader_LoadComponent_InheritsTypeContextPackageForNamedStateSchemas(t *testing.T) { planned := &shape.PlanResult{ - Source: &shape.Source{Name: "/v1/api/report"}, + Source: &shape.Source{Name: "patch_basic_one"}, Plan: &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "patch_basic_one", + DefaultPackage: "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one", + }, Views: []*plan.View{ { - Path: "Rows", - Name: "rows", - Table: "REPORT", - Cardinality: "many", + Name: "foos", + Holder: "Foos", FieldType: reflect.TypeOf([]reportRow{}), ElementType: reflect.TypeOf(reportRow{}), - Relations: []*plan.Relation{ - { - Name: "detail", - Holder: "Detail", - Ref: "detail", - Table: "REPORT_DETAIL", - On: []*plan.RelationLink{ - { - ParentField: "ReportID", - ParentNamespace: "rows", - ParentColumn: "REPORT_ID", - RefField: "ID", - RefNamespace: "detail", - RefColumn: "ID", - }, - }, + Cardinality: string(state.Many), + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: &state.Location{Kind: state.KindRequestBody, Name: ""}, + Schema: &state.Schema{ + Name: "Foos", + DataType: "Foos", + }, + }, + }, + { + Parameter: state.Parameter{ + Name: "Foos", + In: &state.Location{Kind: state.KindOutput, Name: "view"}, + Tag: `anonymous:"true"`, + Schema: &state.Schema{ + Name: "Foos", + DataType: "Foos", }, }, }, }, - ViewsByName: map[string]*plan.View{}, - ByPath: map[string]*plan.Field{}, }, } @@ -229,16 +674,2337 @@ func TestLoader_LoadComponent_RelationFieldsPreserved(t *testing.T) { require.NoError(t, err) component, ok := ComponentFrom(artifact) require.True(t, ok) - require.Len(t, component.ViewRelations, 1) - require.Len(t, component.ViewRelations[0].On, 1) - require.Len(t, component.ViewRelations[0].Of.On, 1) + require.Len(t, component.Input, 1) + require.Len(t, component.Output, 1) + assert.Equal(t, "patch_basic_one", component.Input[0].Schema.Package) + assert.Equal(t, "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one", component.Input[0].Schema.PackagePath) + assert.Equal(t, "patch_basic_one", component.Output[0].Schema.Package) + assert.Equal(t, "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one", component.Output[0].Schema.PackagePath) +} - parent := component.ViewRelations[0].On[0] - ref := component.ViewRelations[0].Of.On[0] - assert.Equal(t, "ReportID", parent.Field) - assert.Equal(t, "rows", parent.Namespace) - assert.Equal(t, "REPORT_ID", parent.Column) - assert.Equal(t, "ID", ref.Field) - assert.Equal(t, "detail", ref.Namespace) - assert.Equal(t, "ID", ref.Column) +func TestLoader_LoadComponent_DoesNotInheritTypeContextPackageForPrimitiveStateSchemas(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "district_pagination"}, + Plan: &plan.Result{ + TypeContext: &typectx.Context{ + DefaultPackage: "github.com/viant/datly/e2e/v1/shape/dev/district/pagination", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/district/pagination", + }, + Views: []*plan.View{ + { + Name: "districts", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Cardinality: string(state.Many), + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "IDs", + In: state.NewQueryLocation("IDs"), + Schema: &state.Schema{DataType: "[]int"}, + }, + }, + { + Parameter: state.Parameter{ + Name: "Page", + In: state.NewQueryLocation("page"), + Schema: &state.Schema{DataType: "int"}, + }, + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.Input, 2) + assert.Empty(t, component.Input[0].Schema.Package) + assert.Empty(t, component.Input[0].Schema.PackagePath) + assert.Empty(t, component.Input[1].Schema.Package) + assert.Empty(t, component.Input[1].Schema.PackagePath) +} + +func TestLoader_LoadComponent_MaterializesAnonymousBodySchemaIntoResourceParameters(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "patch_basic_one"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "foos", + Holder: "Foos", + SchemaType: "*FoosView", + Cardinality: string(state.Many), + SQL: "SELECT * FROM FOOS", + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + }, + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + param, err := artifact.Resource.LookupParameter("Foos") + require.NoError(t, err) + require.NotNil(t, param) + require.NotNil(t, param.Schema) + assert.Equal(t, "FoosView", param.Schema.Name) + assert.Equal(t, "*FoosView", param.Schema.DataType) +} + +func TestLoader_LoadComponent_InheritsTypeContextPackageForNamedViewSchemas(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "patch_basic_one"}, + Plan: &plan.Result{ + TypeContext: &typectx.Context{ + PackageName: "patch_basic_one", + DefaultPackage: "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one", + PackagePath: "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one", + }, + Views: []*plan.View{ + { + Name: "foos", + Holder: "Foos", + SchemaType: "*FoosView", + Cardinality: string(state.Many), + FieldType: nil, + ElementType: nil, + SQL: "SELECT * FROM FOOS", + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + root, err := artifact.Resource.Views.Index().Lookup("foos") + require.NoError(t, err) + require.NotNil(t, root) + require.NotNil(t, root.Schema) + assert.Equal(t, "patch_basic_one", root.Schema.Package) + assert.Equal(t, "github.com/viant/datly/e2e/v1/shape/dev/events/patch_basic_one", root.Schema.PackagePath) +} + +func TestLoader_LoadComponent_PreservesComponentHolderTypes(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &typedRouteSource{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ComponentRoutes, 1) + assert.Equal(t, reflect.TypeOf(typedRouteInput{}), component.ComponentRoutes[0].InputType) + assert.Equal(t, reflect.TypeOf(typedRouteOutput{}), component.ComponentRoutes[0].OutputType) + assert.Empty(t, component.ComponentRoutes[0].InputName) + assert.Empty(t, component.ComponentRoutes[0].OutputName) +} + +func TestLoader_LoadComponent_SynthesizesMutableHelpersForPatchBodyRoute(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "patch_basic_one"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "foos", + Holder: "Foos", + SchemaType: "*FoosView", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Cardinality: string(state.Many), + Table: "FOOS", + SQL: "SELECT * FROM FOOS", + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + Schema: &state.Schema{ + Name: "FoosView", + DataType: "*FoosView", + Cardinality: state.One, + }, + }, + EmitOutput: true, + }, + }, + Components: []*plan.ComponentRoute{ + { + Method: "PATCH", + RoutePath: "/v1/api/shape/dev/basic/foos", + ViewName: "FoosView", + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + + root := lookupNamedResourceView(artifact.Resource, component.RootView) + require.NotNil(t, root) + assert.Equal(t, view.ModeExec, root.Mode) + require.NotNil(t, root.Template) + assert.True(t, root.Template.UseParameterStateType) + require.NotNil(t, root.Template.Parameters.Lookup("CurFoosId")) + require.NotNil(t, root.Template.Parameters.Lookup("CurFoos")) + assert.Contains(t, root.Template.Source, `$sequencer.Allocate("FOOS", $Unsafe.Foos, "Id")`) + assert.Contains(t, root.Template.Source, `#if($CurFoosById.HasKey($Unsafe.Foos.Id) == true)`) + assert.Contains(t, root.Template.Source, `$sql.Update($Unsafe.Foos, "FOOS");`) + assert.Contains(t, root.Template.Source, `$sql.Insert($Unsafe.Foos, "FOOS");`) + assert.Equal(t, state.Many, root.Template.Parameters.Lookup("CurFoos").Schema.Cardinality) + + require.Nil(t, component.InputParameters().Lookup("CurFoosId")) + require.Nil(t, component.InputParameters().Lookup("CurFoos")) + require.NotNil(t, artifact.Resource.Parameters.Lookup("CurFoosId")) + require.NotNil(t, artifact.Resource.Parameters.Lookup("CurFoos")) + assert.Equal(t, "*FoosView", artifact.Resource.Parameters.Lookup("CurFoosId").Schema.DataType) + require.NotNil(t, artifact.Resource.Parameters.Lookup("CurFoosId").Output) + assert.Equal(t, "structql", artifact.Resource.Parameters.Lookup("CurFoosId").Output.Name) + assert.Contains(t, artifact.Resource.Parameters.Lookup("CurFoosId").Output.Body, "SELECT ARRAY_AGG(Id) AS Values") + assert.Equal(t, state.Many, artifact.Resource.Parameters.Lookup("CurFoosId").Schema.Cardinality) + assert.Equal(t, state.One, artifact.Resource.Parameters.Lookup("CurFoosId").Output.Schema.Cardinality) + assert.Equal(t, state.Many, artifact.Resource.Parameters.Lookup("CurFoos").Schema.Cardinality) + require.Len(t, component.Output, 1) + require.Equal(t, "Foos", component.Output[0].Name) + require.Equal(t, state.KindRequestBody, component.Output[0].In.Kind) + + curFoos, err := artifact.Resource.View("CurFoos") + require.NoError(t, err) + require.NotNil(t, curFoos) + require.NotNil(t, curFoos.Template) + assert.Equal(t, "foos/cur_foos.sql", curFoos.Template.SourceURL) + require.True(t, curFoos.Template.UseParameterStateType) + require.True(t, curFoos.Template.DeclaredParametersOnly) + require.True(t, curFoos.Template.UseResourceParameterLookup) + require.NotNil(t, curFoos.Template.Parameters.Lookup("CurFoosId")) + require.Nil(t, curFoos.Template.Parameters.Lookup("Foos")) +} + +func TestLoader_LoadComponent_SynthesizesMutableHelpersForPatchManyBodyRoute(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "patch_basic_many"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "foos", + Holder: "Foos", + SchemaType: "*FoosView", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Cardinality: string(state.Many), + Table: "FOOS", + SQL: "SELECT * FROM FOOS", + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + Schema: &state.Schema{ + Name: "FoosView", + DataType: "*FoosView", + Cardinality: state.Many, + }, + }, + EmitOutput: true, + }, + }, + Components: []*plan.ComponentRoute{ + { + Method: "PATCH", + RoutePath: "/v1/api/shape/dev/basic/foos-many", + ViewName: "FoosView", + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + + curFoosID := artifact.Resource.Parameters.Lookup("CurFoosId") + require.NotNil(t, curFoosID) + require.NotNil(t, curFoosID.Schema) + require.NotNil(t, curFoosID.Output) + require.NotNil(t, curFoosID.Output.Schema) + assert.Equal(t, state.Many, curFoosID.Schema.Cardinality) + assert.Equal(t, state.One, curFoosID.Output.Schema.Cardinality) + assert.Equal(t, "*FoosView", curFoosID.Schema.DataType) + assert.Contains(t, curFoosID.Output.Body, "SELECT ARRAY_AGG(Id) AS Values") + root := lookupNamedResourceView(artifact.Resource, "foos") + require.NotNil(t, root) + require.NotNil(t, root.Template) + assert.Contains(t, root.Template.Source, `$sequencer.Allocate("FOOS", $Unsafe.Foos, "Id")`) + assert.Contains(t, root.Template.Source, `#foreach($RecFoos in $Unsafe.Foos)`) + assert.Contains(t, root.Template.Source, `#if($CurFoosById.HasKey($RecFoos.Id) == true)`) + assert.Contains(t, root.Template.Source, `$sql.Update($RecFoos, "FOOS");`) + assert.Contains(t, root.Template.Source, `$sql.Insert($RecFoos, "FOOS");`) + require.NotNil(t, root.TableBatches) + assert.True(t, root.TableBatches["FOOS"]) +} + +func TestLoader_LoadComponent_DoesNotSynthesizeMutableHelpersForScalarBodyRoute(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "product_update"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "product_update", + Holder: "ProductUpdate", + SchemaType: "*ProductUpdateView", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Cardinality: string(state.Many), + Table: "PRODUCT", + SQL: "UPDATE PRODUCT SET STATUS = $Status WHERE ID IN ($Ids)", + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Ids", + In: state.NewBodyLocation("Ids"), + Schema: &state.Schema{ + DataType: "int", + Cardinality: state.Many, + }, + }, + }, + { + Parameter: state.Parameter{ + Name: "Status", + In: state.NewBodyLocation("Status"), + Schema: &state.Schema{ + DataType: "int", + Cardinality: state.One, + }, + }, + }, + { + Parameter: state.Parameter{ + Name: "Records", + In: state.NewViewLocation("Records"), + Schema: &state.Schema{ + Name: "RecordsView", + DataType: "*RecordsView", + Cardinality: state.Many, + }, + }, + }, + }, + Components: []*plan.ComponentRoute{ + { + Method: "POST", + RoutePath: "/v1/api/shape/dev/auth/products", + ViewName: "ProductUpdateView", + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + + root := lookupNamedResourceView(artifact.Resource, component.RootView) + require.NotNil(t, root) + require.NotNil(t, root.Template) + require.Nil(t, root.Template.Parameters.Lookup("CurIdsId")) + require.Nil(t, artifact.Resource.Parameters.Lookup("CurIdsId")) + require.Nil(t, artifact.Resource.Parameters.Lookup("CurIds")) +} + +func TestLoader_LoadComponent_UserMetadataPreservesBitColumnsFromSchemaType(t *testing.T) { + projectRoot := t.TempDir() + err := os.WriteFile(filepath.Join(projectRoot, "go.mod"), []byte("module github.com/acme/app\n\ngo 1.23.0\n"), 0o644) + require.NoError(t, err) + packageDir := filepath.Join(projectRoot, "shape", "dev", "user", "mysql_boolean") + err = os.MkdirAll(packageDir, 0o755) + require.NoError(t, err) + sourcePath := filepath.Join(projectRoot, "routes", "dev", "user_metadata.dql") + err = os.MkdirAll(filepath.Dir(sourcePath), 0o755) + require.NoError(t, err) + err = os.WriteFile(sourcePath, []byte("SELECT * FROM USER_METADATA"), 0o644) + require.NoError(t, err) + typeFile := `package mysql_boolean + +import "github.com/viant/sqlx/types" + +type UserMetadataView struct { + Id int ` + "`sqlx:\"ID\"`" + ` + UserId *int ` + "`sqlx:\"USER_ID\"`" + ` + IsEnabled *types.BitBool ` + "`sqlx:\"IS_ENABLED\"`" + ` + IsActivated *types.BitBool ` + "`sqlx:\"IS_ACTIVATED\"`" + ` +} +` + err = os.WriteFile(filepath.Join(packageDir, "user_metadata.go"), []byte(typeFile), 0o644) + require.NoError(t, err) + + artifact, err := New().LoadComponent(context.Background(), &shape.PlanResult{ + Source: &shape.Source{Name: "user_metadata", Path: sourcePath}, + Plan: &plan.Result{ + TypeContext: &typectx.Context{ + PackagePath: "github.com/acme/app/shape/dev/user/mysql_boolean", + }, + Views: []*plan.View{ + { + Name: "user_metadata", + Table: "USER_METADATA", + SchemaType: "*UserMetadataView", + Cardinality: string(state.Many), + SQL: "SELECT user_metadata.* FROM (SELECT * FROM USER_METADATA t) user_metadata", + }, + }, + }, + }) + require.NoError(t, err) + root, err := artifact.Resource.Views.Index().Lookup("user_metadata") + require.NoError(t, err) + require.NotNil(t, root) + require.NotNil(t, root.Schema) + require.NotNil(t, root.Schema.Type()) + + names := make([]string, 0, len(root.Columns)) + for _, column := range root.Columns { + if column == nil { + continue + } + names = append(names, column.Name) + } + assert.Contains(t, names, "IS_ENABLED") + assert.Contains(t, names, "IS_ACTIVATED") +} + +func TestLoader_LoadComponent_PreservesDynamicComponentHolderTypes(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &dynamicRouteSource{ + Route: xdatly.Component[any, any]{ + Inout: dynamicRouteInput{}, + Output: dynamicRouteOutput{}, + }, + }}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ComponentRoutes, 1) + assert.Equal(t, reflect.TypeOf(dynamicRouteInput{}), component.ComponentRoutes[0].InputType) + assert.Equal(t, reflect.TypeOf(dynamicRouteOutput{}), component.ComponentRoutes[0].OutputType) +} + +func TestLoader_LoadComponent_PreservesDynamicComponentHolderExplicitNames(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT" sql:"uri=testdata/report.sql"` + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET,input=ReportInput,output=ReportOutput"` + }{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ComponentRoutes, 1) + assert.Nil(t, component.ComponentRoutes[0].InputType) + assert.Nil(t, component.ComponentRoutes[0].OutputType) + assert.Equal(t, "ReportInput", component.ComponentRoutes[0].InputName) + assert.Equal(t, "ReportOutput", component.ComponentRoutes[0].OutputName) +} + +func TestLoader_LoadComponent_PreservesDynamicComponentHolderExplicitNamesFromRegistry(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(namedDynamicRouteInput{}), x.WithPkgPath("github.com/viant/datly/repository/shape/load"), x.WithName("ReportInput"))) + registry.Register(x.NewType(reflect.TypeOf(namedDynamicRouteOutput{}), x.WithPkgPath("github.com/viant/datly/repository/shape/load"), x.WithName("ReportOutput"))) + + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{ + Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET,input=ReportInput,output=ReportOutput"` + }{}, + TypeRegistry: registry, + }) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ComponentRoutes, 1) + assert.Equal(t, reflect.TypeOf(namedDynamicRouteInput{}), component.ComponentRoutes[0].InputType) + assert.Equal(t, reflect.TypeOf(namedDynamicRouteOutput{}), component.ComponentRoutes[0].OutputType) + assert.Equal(t, "ReportInput", component.ComponentRoutes[0].InputName) + assert.Equal(t, "ReportOutput", component.ComponentRoutes[0].OutputName) + require.Len(t, component.Input, 1) + assert.Equal(t, "name", component.Input[0].Name) + require.Len(t, component.Output, 2) + require.Len(t, artifact.Resource.Views, 1) + assert.Equal(t, "rows", artifact.Resource.Views[0].Name) +} + +func TestLoader_LoadComponent_ErrorsOnMultipleComponentRoutes(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT" sql:"uri=testdata/report.sql"` + RouteA struct{} `component:",path=/v1/api/dev/report-a,method=GET"` + RouteB struct{} `component:",path=/v1/api/dev/report-b,method=POST"` + }{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + _, err = loader.LoadComponent(context.Background(), planned) + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple component routes are not supported") +} + +func TestLoader_LoadComponent_RouterOnlySourceSynthesizesStatesAndViews(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &routerOnlySource{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + assert.Equal(t, "/v1/api/dev/router-only", component.URI) + assert.Equal(t, "GET", component.Method) + require.Len(t, component.Input, 1) + assert.Equal(t, "id", component.Input[0].Name) + require.Len(t, component.Output, 2) + require.Len(t, artifact.Resource.Views, 1) + assert.Equal(t, "rows", artifact.Resource.Views[0].Name) +} + +func TestLoader_LoadComponent_SynthesizesStatesFromRouteContractsWhenPlanStatesAreEmpty(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "team"}, + Plan: &plan.Result{ + Components: []*plan.ComponentRoute{ + { + FieldName: "Team", + Name: "Team", + RoutePath: "/v1/api/dev/team/{teamID}", + Method: "DELETE", + InputType: reflect.TypeOf(typedTeamRouteInput{}), + OutputType: reflect.TypeOf(typedTeamRouteOutput{}), + ViewName: "Team", + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.Input, 1) + assert.Equal(t, "TeamID", component.Input[0].Name) + require.NotNil(t, component.Input[0].In) + assert.Equal(t, state.KindPath, component.Input[0].In.Kind) + assert.Equal(t, "teamID", component.Input[0].In.Name) +} + +func TestLoader_LoadComponent_CacheProviderDoesNotBindRootView(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/shape/dev/vendors/"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "vendor", + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]any{}), + ElementType: reflect.TypeOf(map[string]any{}), + SQL: "SELECT * FROM VENDOR", + }, + }, + Directives: &dqlshape.Directives{ + Cache: &dqlshape.CacheDirective{ + Enabled: true, + Name: "aerospike", + Provider: "aerospike://127.0.0.1:3000/test", + Location: "${view.Name}", + TimeToLiveMs: 3600000, + }, + }, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + require.Len(t, artifact.Resource.Views, 1) + require.NotEmpty(t, artifact.Resource.CacheProviders) + + root := artifact.Resource.Views[0] + assert.Nil(t, root.Cache) +} + +func TestLoader_LoadViews_DoesNotSeedPlaceholderColumnsFromLinkedType(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "districts"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "districts", + Table: "DISTRICT", + Cardinality: "many", + SchemaType: "*DistrictsView", + FieldType: reflect.TypeOf([]*placeholderDistrictRow{}), + ElementType: reflect.TypeOf(placeholderDistrictRow{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifact, err := New().LoadViews(context.Background(), planned) + require.NoError(t, err) + require.Len(t, artifact.Views, 1) + assert.Empty(t, artifact.Views[0].Columns) +} + +func TestLoader_LoadViews_DefersMapBackedQuerySchemaType(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "cities"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "cities", + Table: "CITY", + Mode: string(view.ModeQuery), + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM CITY", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifact, err := New().LoadViews(context.Background(), planned) + require.NoError(t, err) + require.Len(t, artifact.Views, 1) + assert.Nil(t, artifact.Views[0].Schema.Type()) +} + +func TestLoader_LoadViews_DefersPlaceholderStructQuerySchemaType(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "districts"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "districts", + Table: "DISTRICT", + Mode: string(view.ModeQuery), + Cardinality: "many", + FieldType: reflect.TypeOf([]placeholderDistrictRow{}), + ElementType: reflect.TypeOf(placeholderDistrictRow{}), + SQL: "SELECT t.* FROM DISTRICT t", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifact, err := New().LoadViews(context.Background(), planned) + require.NoError(t, err) + require.Len(t, artifact.Views, 1) + assert.Nil(t, artifact.Views[0].Schema.Type()) +} + +func TestLoader_LoadComponent_DoesNotMaterializePlaceholderOutputViewSchema(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "districts"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "districts", + Table: "DISTRICT", + Mode: string(view.ModeQuery), + Cardinality: "many", + FieldType: reflect.TypeOf([]placeholderDistrictRow{}), + ElementType: reflect.TypeOf(placeholderDistrictRow{}), + SQL: "SELECT t.* FROM DISTRICT t", + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Data", + In: state.NewOutputLocation("view"), + }, + }, + }, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.Output, 1) + require.NotNil(t, component.Output[0].Schema) + assert.Nil(t, component.Output[0].Schema.Type()) +} + +func TestLoader_LoadViews_PreservesChildSummaryGraph(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM VENDOR", + Relations: []*plan.Relation{ + { + Name: "products", + Parent: "vendor", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + {ParentColumn: "ID", RefColumn: "VENDOR_ID"}, + }, + }, + }, + }, + { + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM PRODUCT", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) PROD_META GROUP BY VENDOR_ID", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifacts, err := loader.LoadViews(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.Len(t, artifacts.Views, 2) + + index := artifacts.Resource.Views.Index() + products, err := index.Lookup("products") + require.NoError(t, err) + require.NotNil(t, products) + require.NotNil(t, products.Template) + require.NotNil(t, products.Template.Summary) + assert.Contains(t, products.Template.Summary.Source, "TOTAL_PRODUCTS") + assert.Contains(t, products.Template.Summary.Source, "$View.products.SQL") +} + +func TestLoader_LoadViews_AttachesChildSummaryFieldToParentSchema(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(productsMetaSummaryRow{}), x.WithName("ProductsMetaView"))) + + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta", TypeRegistry: registry}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorSummaryParentRow{}), + ElementType: reflect.TypeOf(vendorSummaryParentRow{}), + SQL: "SELECT * FROM VENDOR", + Relations: []*plan.Relation{ + { + Name: "products", + Parent: "vendor", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + {ParentColumn: "ID", RefColumn: "VENDOR_ID"}, + }, + }, + }, + }, + { + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorSummaryChildRow{}), + ElementType: reflect.TypeOf(vendorSummaryChildRow{}), + SQL: "SELECT * FROM PRODUCT", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) PROD_META GROUP BY VENDOR_ID", + SummaryName: "ProductsMeta", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadViews(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.Len(t, artifacts.Views, 2) + + index := artifacts.Resource.Views.Index() + root, err := index.Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, root) + require.NotNil(t, root.Schema) + compType := root.Schema.CompType() + require.NotNil(t, compType) + field, ok := compType.FieldByName("ProductsMeta") + require.True(t, ok) + assert.Equal(t, `json:",omitempty" yaml:",omitempty" sqlx:"-"`, string(field.Tag)) +} + +func TestLoader_LoadResource_AssignsSummarySchemas(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(metaSummaryRow{}), x.WithName("MetaView"))) + registry.Register(x.NewType(reflect.TypeOf(productsMetaSummaryRow{}), x.WithName("ProductsMetaView"))) + + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta", TypeRegistry: registry}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM VENDOR", + Summary: "SELECT COUNT(*) AS CNT FROM ($View.NonWindowSQL) t", + SummaryName: "Meta", + Relations: []*plan.Relation{ + { + Name: "products", + Parent: "vendor", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + {ParentColumn: "ID", RefColumn: "VENDOR_ID"}, + }, + }, + }, + }, + { + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM PRODUCT", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) PROD_META GROUP BY VENDOR_ID", + SummaryName: "ProductsMeta", + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Meta", + In: state.NewOutputLocation("summary"), + Schema: &state.Schema{DataType: "?"}, + }, + }, + }, + Components: []*plan.ComponentRoute{ + { + RoutePath: "/v1/api/dev/meta/vendors-nested", + Method: "GET", + ViewName: "vendor", + Name: "vendor", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadResource(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.NotNil(t, artifacts.Resource) + + index := artifacts.Resource.Views.Index() + root, err := index.Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, root) + require.NotNil(t, root.Template) + require.NotNil(t, root.Template.Summary) + require.NotNil(t, root.Template.Summary.Schema) + assert.Equal(t, "*load.metaSummaryRow", root.Template.Summary.Schema.Type().String()) + + products, err := index.Lookup("products") + require.NoError(t, err) + require.NotNil(t, products) + require.NotNil(t, products.Template) + require.NotNil(t, products.Template.Summary) + require.NotNil(t, products.Template.Summary.Schema) + assert.Equal(t, "*load.productsMetaSummaryRow", products.Template.Summary.Schema.Type().String()) + productsSummaryType := products.Template.Summary.Schema.Type() + if productsSummaryType.Kind() == reflect.Ptr { + productsSummaryType = productsSummaryType.Elem() + } + productsSummaryField, ok := productsSummaryType.FieldByName("VendorID") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf(int(0)), productsSummaryField.Type) + + metaParam, err := artifacts.Resource.LookupParameter("Meta") + require.NoError(t, err) + require.NotNil(t, metaParam) + require.NotNil(t, metaParam.Schema) + require.NotNil(t, metaParam.Schema.Type()) + assert.Equal(t, "*load.metaSummaryRow", metaParam.Schema.Type().String()) + + componentArtifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, componentArtifact) + component, ok := componentArtifact.Component.(*Component) + require.True(t, ok) + require.NotEmpty(t, component.Output) + require.NotNil(t, component.Output[0].Schema) + require.NotNil(t, component.Output[0].Schema.Type()) + assert.Equal(t, "*load.metaSummaryRow", component.Output[0].Schema.Type().String()) +} + +func TestRefineSummarySchemas_PrefersOwnerSchemaFieldTypeOverDiscoveredColumn(t *testing.T) { + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "products", + Schema: &state.Schema{ + Name: "ProductsView", + DataType: "*ProductsView", + Cardinality: state.Many, + }, + Columns: []*view.Column{ + {Name: "VENDOR_ID", DatabaseColumn: "VENDOR_ID", DataType: "int"}, + }, + Template: &view.Template{ + Summary: &view.TemplateSummary{ + Name: "ProductsMeta", + Schema: &state.Schema{ + Name: "ProductsMetaView", + DataType: "*ProductsMetaView", + Cardinality: state.One, + }, + }, + }, + }, + }, + } + resource.Views[0].Schema.SetType(reflect.TypeOf([]productsOwnerPointerRow{})) + resource.Views[0].Template.Summary.Schema.SetType(reflect.TypeOf(productsMetaSummaryRow{})) + + RefineSummarySchemas(resource) + + summaryType := resource.Views[0].Template.Summary.Schema.Type() + require.NotNil(t, summaryType) + if summaryType.Kind() == reflect.Ptr { + summaryType = summaryType.Elem() + } + field, ok := summaryType.FieldByName("VendorID") + require.True(t, ok) + assert.Equal(t, reflect.TypeOf((*int)(nil)), field.Type) +} + +func TestLoader_LoadResource_AttachesChildSummaryTemplateToRelationView(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(metaSummaryRow{}), x.WithName("MetaView"))) + registry.Register(x.NewType(reflect.TypeOf(productsMetaSummaryRow{}), x.WithName("ProductsMetaView"))) + + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta", TypeRegistry: registry}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorSummaryParentRow{}), + ElementType: reflect.TypeOf(vendorSummaryParentRow{}), + SQL: "SELECT * FROM VENDOR", + Summary: "SELECT COUNT(*) AS CNT FROM ($View.NonWindowSQL) t", + SummaryName: "Meta", + Relations: []*plan.Relation{ + { + Name: "products", + Parent: "vendor", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + {ParentColumn: "ID", RefColumn: "VENDOR_ID"}, + }, + }, + }, + }, + { + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorSummaryChildRow{}), + ElementType: reflect.TypeOf(vendorSummaryChildRow{}), + SQL: "SELECT * FROM PRODUCT", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) PROD_META GROUP BY VENDOR_ID", + SummaryName: "ProductsMeta", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadResource(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + root, err := artifacts.Resource.Views.Index().Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, root) + require.Len(t, root.With, 1) + require.NotNil(t, root.With[0].Of.Template) + require.NotNil(t, root.With[0].Of.Template.Summary) + require.NotNil(t, root.With[0].Of.Template.Summary.Schema) + assert.Equal(t, "ProductsMetaView", root.With[0].Of.Template.Summary.Schema.Name) + + products, err := artifacts.Resource.Views.Index().Lookup("products") + require.NoError(t, err) + require.NotNil(t, products) + require.NotNil(t, products.Template) + require.NotNil(t, products.Template.Summary) + require.NotNil(t, products.Template.Summary.Schema) + + relationSummaryType := root.With[0].Of.Template.Summary.Schema.Type() + require.NotNil(t, relationSummaryType) + if relationSummaryType.Kind() == reflect.Ptr { + relationSummaryType = relationSummaryType.Elem() + } + relationField, ok := relationSummaryType.FieldByName("VendorID") + require.True(t, ok) + assert.NotEqual(t, "true", relationField.Tag.Get("internal")) + + standaloneSummaryType := products.Template.Summary.Schema.Type() + require.NotNil(t, standaloneSummaryType) + if standaloneSummaryType.Kind() == reflect.Ptr { + standaloneSummaryType = standaloneSummaryType.Elem() + } + standaloneField, ok := standaloneSummaryType.FieldByName("VendorID") + require.True(t, ok) + assert.NotEqual(t, "true", standaloneField.Tag.Get("internal")) + assert.Equal(t, standaloneSummaryType, relationSummaryType) +} + +func TestLoader_LoadResource_MaterializesNamedResourceTypes(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(metaSummaryRow{}), x.WithName("MetaView"))) + registry.Register(x.NewType(reflect.TypeOf(productsMetaSummaryRow{}), x.WithName("ProductsMetaView"))) + + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta", TypeRegistry: registry}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorSummaryParentRow{}), + ElementType: reflect.TypeOf(vendorSummaryParentRow{}), + SQL: "SELECT * FROM VENDOR", + Summary: "SELECT COUNT(*) AS CNT FROM ($View.NonWindowSQL) t", + SummaryName: "Meta", + Relations: []*plan.Relation{ + { + Name: "products", + Parent: "vendor", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + {ParentColumn: "ID", RefColumn: "VENDOR_ID"}, + }, + }, + }, + }, + { + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorSummaryChildRow{}), + ElementType: reflect.TypeOf(vendorSummaryChildRow{}), + SQL: "SELECT * FROM PRODUCT", + Summary: "SELECT VENDOR_ID, COUNT(*) AS TOTAL_PRODUCTS FROM ($View.products.SQL) PROD_META GROUP BY VENDOR_ID", + SummaryName: "ProductsMeta", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadResource(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.NotNil(t, artifacts.Resource) + + var actual []string + for _, item := range artifacts.Resource.Types { + if item == nil { + continue + } + actual = append(actual, item.Name) + } + assert.ElementsMatch(t, []string{"VendorView", "MetaView", "ProductsView", "ProductsMetaView"}, actual) +} + +func TestLoader_LoadViews_PreservesSummarySourceURL(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQLURI: "vendor/vendor.sql", + SummaryURL: "vendor/vendor_summary.sql", + SummaryName: "Meta", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadViews(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.Len(t, artifacts.Views, 1) + require.NotNil(t, artifacts.Views[0].Template) + require.NotNil(t, artifacts.Views[0].Template.Summary) + assert.Equal(t, "vendor/vendor_summary.sql", artifacts.Views[0].Template.Summary.SourceURL) +} + +func TestLoader_LoadResource_TypedViewDefinitionsPreferSchemaFieldsOverColumns(t *testing.T) { + type foosViewHas struct { + Id bool + Name bool + Quantity bool + } + type foosView struct { + Id int `sqlx:"ID" velty:"names=ID|Id"` + Name *string `sqlx:"NAME" velty:"names=NAME|Name"` + Quantity *int `sqlx:"QUANTITY" velty:"names=QUANTITY|Quantity"` + Has *foosViewHas `setMarker:"true" format:"-" sqlx:"-" diff:"-" json:"-" typeName:"FoosViewHas"` + } + + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "patch_basic_one"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "foos", + Table: "FOOS", + Cardinality: "many", + FieldType: reflect.TypeOf([]foosView{}), + ElementType: reflect.TypeOf(foosView{}), + SQL: "SELECT * FROM FOOS", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadResource(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.NotNil(t, artifacts.Resource) + + var actual []string + for _, item := range artifacts.Resource.Types { + if item == nil || item.Name != "FoosView" { + continue + } + for _, field := range item.Fields { + actual = append(actual, field.Name) + } + } + assert.Equal(t, []string{"Id", "Name", "Quantity", "Has"}, actual) +} + +func TestLoader_LoadViews_InferRelationLinkFieldsFromSchemaTypes(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "vendor"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + SQL: "SELECT * FROM VENDOR", + Relations: []*plan.Relation{ + { + Name: "products", + Parent: "vendor", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + { + ParentNamespace: "vendor", + ParentColumn: "ID", + RefNamespace: "products", + RefColumn: "VENDOR_ID", + }, + }, + }, + }, + }, + { + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorProductRow{}), + ElementType: reflect.TypeOf(vendorProductRow{}), + SQL: "SELECT * FROM PRODUCT", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadViews(context.Background(), planned) + require.NoError(t, err) + + root, err := artifacts.Resource.Views.Index().Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, root) + require.Len(t, root.With, 1) + require.Len(t, root.With[0].On, 1) + require.Len(t, root.With[0].Of.On, 1) + + assert.Equal(t, "ID", root.With[0].On[0].Column) + assert.Equal(t, "ID", root.With[0].On[0].Field) + assert.Empty(t, root.With[0].On[0].Namespace) + assert.Equal(t, "VENDOR_ID", root.With[0].Of.On[0].Column) + assert.Equal(t, "VendorID", root.With[0].Of.On[0].Field) + assert.Empty(t, root.With[0].Of.On[0].Namespace) +} + +func TestLoader_LoadViews_FallsBackToColumnNamesForRelationLinkFields(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "vendor"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM VENDOR", + Relations: []*plan.Relation{ + { + Name: "products", + Parent: "vendor", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + { + ParentColumn: "ID", + RefColumn: "VENDOR_ID", + }, + }, + }, + }, + }, + { + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM PRODUCT", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifacts, err := New().LoadViews(context.Background(), planned) + require.NoError(t, err) + + root, err := artifacts.Resource.Views.Index().Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, root) + require.Len(t, root.With, 1) + require.Len(t, root.With[0].On, 1) + require.Len(t, root.With[0].Of.On, 1) + + assert.Equal(t, "Id", root.With[0].On[0].Field) + assert.Equal(t, "VendorId", root.With[0].Of.On[0].Field) +} + +func TestLoader_LoadComponent_ConstDirectiveCreatesInternalConstParameter(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/shape/dev/vendors-env/"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "vendor", + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]any{}), + ElementType: reflect.TypeOf(map[string]any{}), + SQL: "SELECT * FROM VENDOR", + }, + }, + Directives: &dqlshape.Directives{ + Const: map[string]string{ + "Vendor": "VENDOR", + }, + }, + Const: map[string]string{ + "Vendor": "VENDOR", + }, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + require.NotEmpty(t, artifact.Resource.Parameters) + require.NotEmpty(t, artifact.Component) + + var constParam *state.Parameter + for _, item := range artifact.Resource.Parameters { + if item != nil && item.Name == "Vendor" && item.In != nil && item.In.Kind == state.KindConst { + constParam = item + break + } + } + require.NotNil(t, constParam) + assert.Equal(t, "VENDOR", constParam.Value) + assert.Equal(t, `internal:"true"`, constParam.Tag) + require.NotNil(t, constParam.Schema) + assert.Equal(t, "string", constParam.Schema.DataType) + assert.Equal(t, state.One, constParam.Schema.Cardinality) + + loaded, ok := ComponentFrom(artifact) + require.True(t, ok) + var constInput *plan.State + for _, item := range loaded.Input { + if item != nil && item.Name == "Vendor" && item.In != nil && item.In.Kind == state.KindConst { + constInput = item + break + } + } + require.NotNil(t, constInput) + assert.Equal(t, "VENDOR", constInput.Value) + assert.Equal(t, `internal:"true"`, constInput.Tag) + require.NotNil(t, constInput.Schema) + assert.Equal(t, "string", constInput.Schema.DataType) + assert.Equal(t, state.One, constInput.Schema.Cardinality) +} + +func TestLoader_LoadComponent_ViewKindStateIsInput(t *testing.T) { + required := true + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/auth/vendors/{vendorID}"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "vendor", + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]any{}), + ElementType: reflect.TypeOf(map[string]any{}), + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Jwt", + In: state.NewHeaderLocation("Authorization"), + Required: &required, + ErrorStatusCode: 401, + Schema: &state.Schema{DataType: "string"}, + }, + }, + { + Parameter: state.Parameter{ + Name: "Authorization", + In: state.NewViewLocation("Authorization"), + Required: &required, + ErrorStatusCode: 403, + Schema: &state.Schema{Cardinality: state.Many}, + }, + }, + { + Parameter: state.Parameter{ + Name: "VendorID", + In: state.NewPathLocation("vendorID"), + Required: &required, + Schema: &state.Schema{DataType: "int"}, + }, + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component) + + require.Len(t, component.Input, 3) + var hasViewInput bool + for _, input := range component.Input { + if input != nil && input.In != nil && input.In.Kind == state.KindView && input.Name == "Authorization" { + hasViewInput = true + assert.Equal(t, 403, input.ErrorStatusCode) + break + } + } + assert.True(t, hasViewInput) +} + +func TestResolveTypeSpecs_ViewOverridesAndInheritance(t *testing.T) { + result := &plan.Result{ + Directives: &dqlshape.Directives{Dest: "all.go"}, + Views: []*plan.View{ + { + Name: "vendor", + Path: "vendor", + Declaration: &plan.ViewDeclaration{ + Dest: "vendor.go", + TypeName: "Vendor", + }, + }, + {Name: "products", Path: "vendor.products"}, + }, + } + specs := resolveTypeSpecs(result) + require.NotNil(t, specs) + require.NotNil(t, specs["view:vendor"]) + assert.Equal(t, "Vendor", specs["view:vendor"].TypeName) + assert.Equal(t, "vendor.go", specs["view:vendor"].Dest) + require.NotNil(t, specs["view:products"]) + assert.Equal(t, "vendor.go", specs["view:products"].Dest) + assert.True(t, specs["view:products"].Inherited) +} + +func TestLoader_LoadComponent_RelationFieldsPreserved(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/report"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "Rows", + Name: "rows", + Table: "REPORT", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Relations: []*plan.Relation{ + { + Name: "detail", + Holder: "Detail", + Ref: "detail", + Table: "REPORT_DETAIL", + On: []*plan.RelationLink{ + { + ParentField: "ReportID", + ParentNamespace: "rows", + ParentColumn: "REPORT_ID", + RefField: "ID", + RefNamespace: "detail", + RefColumn: "ID", + }, + }, + }, + }, + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Len(t, component.ViewRelations, 1) + require.Len(t, component.ViewRelations[0].On, 1) + require.Len(t, component.ViewRelations[0].Of.On, 1) + + parent := component.ViewRelations[0].On[0] + ref := component.ViewRelations[0].Of.On[0] + assert.Equal(t, "ReportID", parent.Field) + assert.Equal(t, "rows", parent.Namespace) + assert.Equal(t, "REPORT_ID", parent.Column) + assert.Equal(t, "ID", ref.Field) + assert.Equal(t, "detail", ref.Namespace) + assert.Equal(t, "ID", ref.Column) +} + +func TestLoader_LoadComponent_ExecViewInputPreservesVeltyAliases(t *testing.T) { + required := false + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "user_team"}, + Plan: &plan.Result{ + Components: []*plan.ComponentRoute{ + { + Name: "user_team", + Method: "PUT", + Path: "/v1/api/shape/dev/teams", + ViewName: "user_team", + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "TeamIDs", + In: state.NewQueryLocation("TeamIDs"), + Required: &required, + Schema: &state.Schema{DataType: "[]int", Cardinality: state.Many}, + }, + }, + { + Parameter: state.Parameter{ + Name: "TeamStats", + In: state.NewViewLocation("TeamStats"), + Required: &required, + Schema: &state.Schema{Cardinality: state.Many}, + }, + }, + }, + Views: []*plan.View{ + { + Name: "user_team", + Table: "TEAM", + Cardinality: "many", + SQL: "UPDATE TEAM SET ACTIVE = false", + Mode: string(view.ModeExec), + FieldType: reflect.TypeOf([]struct { + Id int `sqlx:"ID"` + }{}), + ElementType: reflect.TypeOf(struct { + Id int `sqlx:"ID"` + }{}), + }, + { + Name: "TeamStats", + Table: "TEAM", + Cardinality: "many", + Mode: string(view.ModeQuery), + ColumnsDiscovery: true, + }, + }, + ByPath: map[string]*plan.Field{}, + ViewsByName: map[string]*plan.View{}, + }, + } + plannedResult, ok := plan.ResultFrom(planned) + require.True(t, ok) + require.Len(t, plannedResult.Views, 2) + plannedResult.Views[1].FieldType = reflect.TypeOf([]struct { + Id int `sqlx:"ID"` + TeamMembers int `sqlx:"TEAM_MEMBERS"` + Name *string `sqlx:"NAME"` + }{}) + plannedResult.Views[1].ElementType = plannedResult.Views[1].FieldType.Elem() + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component) + + resourceView, err := artifact.Resource.View("TeamStats") + require.NoError(t, err) + require.NotNil(t, resourceView) + require.NotNil(t, resourceView.Schema) + resourceType := resourceView.Schema.Type() + require.NotNil(t, resourceType) + if resourceType.Kind() == reflect.Slice { + resourceType = resourceType.Elem() + } + if resourceType.Kind() == reflect.Ptr { + resourceType = resourceType.Elem() + } + require.Equal(t, reflect.Struct, resourceType.Kind()) + + var inputParam *plan.State + for _, item := range component.Input { + if item != nil && strings.EqualFold(item.Name, "TeamStats") { + inputParam = item + break + } + } + require.NotNil(t, inputParam) + require.NotNil(t, inputParam.Schema) + inputType := inputParam.Schema.Type() + require.NotNil(t, inputType) + if inputType.Kind() == reflect.Slice { + inputType = inputType.Elem() + } + if inputType.Kind() == reflect.Ptr { + inputType = inputType.Elem() + } + require.Equal(t, reflect.Struct, inputType.Kind()) + idField, ok := inputType.FieldByName("Id") + require.True(t, ok) + assert.Equal(t, "names=ID|Id", idField.Tag.Get("velty")) +} + +func TestLoader_LoadComponent_AttachesRelationTreeByParent(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/tree"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "root", + Name: "root", + Table: "ROOT", + Cardinality: "many", + FieldType: reflect.TypeOf([]relationTreeRoot{}), + ElementType: reflect.TypeOf(relationTreeRoot{}), + Relations: []*plan.Relation{ + { + Name: "child", + Parent: "root", + Holder: "Child", + Ref: "child", + Table: "CHILD", + On: []*plan.RelationLink{ + { + ParentNamespace: "root", + ParentColumn: "ID", + RefNamespace: "child", + RefField: "RootID", + RefColumn: "RootID", + }, + }, + }, + { + Name: "grand_child", + Parent: "child", + Holder: "GrandChild", + Ref: "grand_child", + Table: "GRAND_CHILD", + On: []*plan.RelationLink{ + { + ParentNamespace: "child", + ParentColumn: "ID", + RefNamespace: "grand_child", + RefField: "ChildID", + RefColumn: "ChildID", + }, + }, + }, + }, + }, + { + Path: "child", + Name: "child", + Table: "CHILD", + Cardinality: "many", + FieldType: reflect.TypeOf([]relationTreeChild{}), + ElementType: reflect.TypeOf(relationTreeChild{}), + }, + { + Path: "grand_child", + Name: "grand_child", + Table: "GRAND_CHILD", + Cardinality: "many", + FieldType: reflect.TypeOf([]relationTreeGrandChild{}), + ElementType: reflect.TypeOf(relationTreeGrandChild{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + + index := artifact.Resource.Views.Index() + root, err := index.Lookup("root") + require.NoError(t, err) + require.NotNil(t, root) + require.Len(t, root.With, 1) + assert.Equal(t, "child", root.With[0].Of.View.Ref) + + child, err := index.Lookup("child") + require.NoError(t, err) + require.NotNil(t, child) + require.Len(t, child.With, 1) + assert.Equal(t, "grand_child", child.With[0].Of.View.Ref) +} + +func TestLoader_LoadComponent_AugmentsRelationHolderField(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/districts"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "districts", + Name: "districts", + Table: "DISTRICT", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Relations: []*plan.Relation{ + { + Name: "cities", + Parent: "districts", + Holder: "Cities", + Ref: "cities", + Table: "CITY", + On: []*plan.RelationLink{ + { + ParentField: "ID", + ParentColumn: "ID", + RefField: "DistrictID", + RefColumn: "DistrictID", + }, + }, + }, + }, + }, + { + Path: "cities", + Name: "cities", + Table: "CITY", + Cardinality: "many", + FieldType: reflect.TypeOf([]cityRow{}), + ElementType: reflect.TypeOf(cityRow{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + + index := artifact.Resource.Views.Index() + root, err := index.Lookup("districts") + require.NoError(t, err) + require.NotNil(t, root) + compType := root.ComponentType() + require.NotNil(t, compType) + field, ok := compType.FieldByName("Cities") + require.True(t, ok) + assert.Equal(t, reflect.Slice, field.Type.Kind()) + assert.Equal(t, reflect.Ptr, field.Type.Elem().Kind()) + assert.Equal(t, "cityRow", field.Type.Elem().Elem().Name()) + assert.Contains(t, string(field.Tag), `view:",table=CITY"`) + assert.Contains(t, string(field.Tag), `on:"ID:ID=DistrictID:DistrictID"`) +} + +func TestLoader_LoadComponent_AugmentsRelationHolderField_ForMapBackedChildView(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/districts"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "districts", + Name: "districts", + Table: "DISTRICT", + Cardinality: "many", + FieldType: reflect.TypeOf([]placeholderDistrictRow{}), + ElementType: reflect.TypeOf(placeholderDistrictRow{}), + Relations: []*plan.Relation{ + { + Name: "cities", + Parent: "districts", + Holder: "Cities", + Ref: "cities", + Table: "CITY", + On: []*plan.RelationLink{ + { + ParentField: "ID", + ParentColumn: "ID", + RefField: "DistrictID", + RefColumn: "DISTRICT_ID", + }, + }, + }, + }, + }, + { + Path: "cities", + Name: "cities", + Table: "CITY", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + artifact, err := New().LoadComponent(context.Background(), planned, shape.WithLoadTypeContextPackages(true)) + require.NoError(t, err) + require.NotNil(t, artifact) + + root, err := artifact.Resource.Views.Index().Lookup("districts") + require.NoError(t, err) + require.NotNil(t, root) + + compType := root.ComponentType() + require.NotNil(t, compType) + + field, ok := compType.FieldByName("Cities") + require.True(t, ok) + assert.Equal(t, reflect.Slice, field.Type.Kind()) + assert.Equal(t, reflect.Struct, field.Type.Elem().Kind()) + assert.Contains(t, string(field.Tag), `view:",table=CITY"`) +} + +func TestLoader_LoadComponent_AugmentsResolvedParentRelationHolderField(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/vendor-details"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "wrapper", + Name: "wrapper", + Table: "WRAPPER", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Relations: []*plan.Relation{ + { + Name: "products", + Holder: "Products", + Ref: "products", + Table: "PRODUCT", + On: []*plan.RelationLink{ + { + ParentNamespace: "vendor", + ParentField: "ID", + ParentColumn: "ID", + RefNamespace: "products", + RefField: "VendorID", + RefColumn: "VENDOR_ID", + }, + }, + }, + }, + }, + { + Path: "vendor", + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + }, + { + Path: "products", + Name: "products", + Table: "PRODUCT", + Cardinality: "many", + FieldType: reflect.TypeOf([]vendorProductRow{}), + ElementType: reflect.TypeOf(vendorProductRow{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + index := artifact.Resource.Views.Index() + vendor, err := index.Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, vendor) + compType := vendor.ComponentType() + require.NotNil(t, compType) + field, ok := compType.FieldByName("Products") + require.True(t, ok) + assert.Equal(t, reflect.Slice, field.Type.Kind()) + assert.Equal(t, reflect.Ptr, field.Type.Elem().Kind()) + assert.Equal(t, "vendorProductRow", field.Type.Elem().Elem().Name()) +} + +func TestLoader_LoadComponent_InfersOneToOneOnSameTableJoin(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/vendor-details"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "wrapper", + Name: "wrapper", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Relations: []*plan.Relation{ + { + Name: "vendor", + Parent: "wrapper", + Holder: "Vendor", + Ref: "vendor", + Table: "VENDOR", + On: []*plan.RelationLink{ + { + ParentNamespace: "wrapper", + ParentColumn: "ID", + RefNamespace: "vendor", + RefColumn: "ID", + }, + }, + }, + { + Name: "setting", + Parent: "wrapper", + Holder: "Setting", + Ref: "setting", + Table: "T", + On: []*plan.RelationLink{ + { + ParentNamespace: "wrapper", + ParentColumn: "ID", + RefNamespace: "setting", + RefColumn: "ID", + }, + }, + }, + }, + }, + { + Path: "vendor", + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }, + { + Path: "setting", + Name: "setting", + Table: "T", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + + index := artifact.Resource.Views.Index() + root, err := index.Lookup("wrapper") + require.NoError(t, err) + require.NotNil(t, root) + require.Len(t, root.With, 2) + + assert.Equal(t, state.One, root.With[0].Cardinality) + assert.Equal(t, state.Many, root.With[1].Cardinality) +} + +func TestLoader_LoadComponent_IncludesComponentStateInInput(t *testing.T) { + required := true + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/vendor"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "vendor", + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Auth", + In: state.NewComponent("GET:/v1/api/dev/auth"), + Required: &required, + Schema: &state.Schema{DataType: "*Output", Package: "auth"}, + }, + }, + }, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component) + require.Len(t, component.Input, 1) + assert.Equal(t, state.KindComponent, component.Input[0].In.Kind) + assert.Equal(t, "Auth", component.Input[0].Name) + require.NotNil(t, component.Input[0].Schema) + assert.Equal(t, reflect.TypeOf((*interface{})(nil)).Elem(), component.Input[0].Schema.Type()) +} + +func TestLoader_LoadComponent_MaterializesOutputStatusSchema(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/user"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "user", + Name: "user", + Table: "USER", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Status", + In: state.NewOutputLocation("status"), + Tag: `anonymous:"true"`, + }, + }, + }, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + + param, err := artifact.Resource.LookupParameter("Status") + require.NoError(t, err) + require.NotNil(t, param) + require.NotNil(t, param.Schema) + require.NotNil(t, param.Schema.Type()) + assert.Equal(t, "Status", param.Schema.Type().Name()) +} + +func TestLoader_LoadComponent_PreservesExplicitOutputViewCardinality(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/auth/user-acl"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "user_acl", + Name: "user_acl", + Table: "USER_ACL", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Data", + In: state.NewOutputLocation("view"), + Schema: &state.Schema{ + Cardinality: state.One, + }, + }, + }, + }, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component) + require.Len(t, component.Output, 1) + require.NotNil(t, component.Output[0].Schema) + assert.Equal(t, state.One, component.Output[0].Schema.Cardinality) +} + +func TestLoader_LoadComponent_RequiredViewInputDefaultsToOneCardinality(t *testing.T) { + required := true + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/auth/vendor"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "Authorization", + Name: "Authorization", + Table: "AUTH", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + }, + }, + States: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Authorization", + In: state.NewViewLocation("Authorization"), + Required: &required, + }, + }, + }, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.NotNil(t, component) + require.Len(t, component.Input, 1) + require.NotNil(t, component.Input[0].Schema) + assert.Equal(t, state.One, component.Input[0].Schema.Cardinality) +} + +func TestLoader_LoadComponent_UsesPlannedRefCardinality(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/vendor-meta"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "vendor", + Name: "vendor", + Table: "VENDOR", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Relations: []*plan.Relation{ + { + Name: "products_meta", + Parent: "vendor", + Holder: "ProductsMeta", + Ref: "products_meta", + Table: "PRODUCT", + On: []*plan.RelationLink{ + { + ParentNamespace: "vendor", + ParentColumn: "ID", + RefNamespace: "products_meta", + RefColumn: "VENDOR_ID", + }, + }, + }, + }, + }, + { + Path: "products_meta", + Name: "products_meta", + Table: "PRODUCT", + Cardinality: "one", + FieldType: reflect.TypeOf(map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + + index := artifact.Resource.Views.Index() + root, err := index.Lookup("vendor") + require.NoError(t, err) + require.NotNil(t, root) + require.Len(t, root.With, 1) + assert.Equal(t, state.One, root.With[0].Cardinality) +} + +func TestLoader_LoadComponent_SynthesizesRootViewFromComponentRoute(t *testing.T) { + baseDir := t.TempDir() + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(reportRow{}), x.WithPkgPath("example.com/routes"), x.WithName("ReportView"))) + + artifact, err := New().LoadComponent(context.Background(), &shape.PlanResult{ + Source: &shape.Source{ + Path: filepath.Join(baseDir, "router.go"), + TypeRegistry: registry, + }, + Plan: &plan.Result{ + Components: []*plan.ComponentRoute{{ + Name: "Report", + RoutePath: "/v1/api/report", + Method: "DELETE", + Connector: "dev", + ViewName: "example.com/routes.ReportView", + SourceURL: "report/report.sql", + }}, + }, + }) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + require.Equal(t, "Report", component.RootView) + require.Len(t, artifact.Resource.Views, 1) + require.Equal(t, "Report", artifact.Resource.Views[0].Name) + require.NotNil(t, artifact.Resource.Views[0].Template) + require.Equal(t, filepath.Join(baseDir, "report", "report.sql"), artifact.Resource.Views[0].Template.SourceURL) +} + +func TestLoader_LoadComponent_AllowsViewlessComponentRoute(t *testing.T) { + artifact, err := New().LoadComponent(context.Background(), &shape.PlanResult{ + Source: &shape.Source{Name: "delete_team"}, + Plan: &plan.Result{ + Components: []*plan.ComponentRoute{ + { + Name: "Team", + Method: "DELETE", + RoutePath: "/v1/api/dev/team/{teamID}", + Connector: "dev", + }, + }, + }, + }) + require.NoError(t, err) + require.NotNil(t, artifact) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + assert.Equal(t, "DELETE", component.Method) + assert.Equal(t, "/v1/api/dev/team/{teamID}", component.URI) + assert.Empty(t, artifact.Resource.Views) +} + +func TestLoader_LoadComponent_QuerySelectorHolder(t *testing.T) { + scanned, err := scan.New().Scan(context.Background(), &shape.Source{Struct: &selectorHolderSource{}}) + require.NoError(t, err) + planned, err := plan.New().Plan(context.Background(), scanned) + require.NoError(t, err) + + artifact, err := New().LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := ComponentFrom(artifact) + require.True(t, ok) + + require.Equal(t, []string{"fields", "page"}, component.QuerySelectors["rows"]) + fields := component.InputParameters().Lookup("fields") + require.NotNil(t, fields) + assert.Equal(t, state.KindQuery, fields.In.Kind) + assert.Equal(t, "_fields", fields.In.Name) + page := component.InputParameters().Lookup("page") + require.NotNil(t, page) + assert.Equal(t, "_page", page.In.Name) } diff --git a/repository/shape/load/model.go b/repository/shape/load/model.go index a05f2287d..f091ee336 100644 --- a/repository/shape/load/model.go +++ b/repository/shape/load/model.go @@ -1,11 +1,15 @@ package load import ( + "reflect" + "github.com/viant/datly/repository/shape" dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" "github.com/viant/datly/repository/shape/typectx" "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/xreflect" ) // Component is a shape-loaded runtime-neutral component artifact. @@ -14,6 +18,7 @@ type Component struct { Name string URI string Method string + ComponentRoutes []*plan.ComponentRoute RootView string Views []string Relations []*plan.Relation @@ -23,7 +28,9 @@ type Component struct { Predicates map[string][]*plan.ViewPredicate TypeContext *typectx.Context Directives *dqlshape.Directives + Report *dqlshape.ReportDirective ColumnsDiscovery bool + TypeSpecs map[string]*TypeSpec Input []*plan.State Output []*plan.State @@ -32,9 +39,76 @@ type Component struct { Other []*plan.State } +type TypeRole string + +const ( + TypeRoleInput TypeRole = "input" + TypeRoleOutput TypeRole = "output" + TypeRoleView TypeRole = "view" +) + +type TypeSpec struct { + Key string + Role TypeRole + Alias string + TypeName string + Dest string + Inherited bool + Source string +} + // ShapeSpecKind implements shape.ComponentSpec. func (c *Component) ShapeSpecKind() string { return "component" } +// InputParameters returns input states as state.Parameters for type generation. +func (c *Component) InputParameters() state.Parameters { + if c == nil { + return nil + } + var result state.Parameters + for _, s := range c.Input { + if s != nil { + p := s.Parameter + result = append(result, &p) + } + } + return result +} + +// OutputParameters returns output states as state.Parameters for type generation. +func (c *Component) OutputParameters() state.Parameters { + if c == nil { + return nil + } + var result state.Parameters + for _, s := range c.Output { + if s != nil { + p := s.Parameter + result = append(result, &p) + } + } + return result +} + +// InputReflectType builds the Input struct reflect.Type using state.Parameters.ReflectType. +// This produces the same struct shape as the legacy codegen (with parameter tags, Has markers, etc.). +func (c *Component) InputReflectType(pkgPath string, lookupType xreflect.LookupType, opts ...state.ReflectOption) (reflect.Type, error) { + params := c.InputParameters() + if len(params) == 0 { + return nil, nil + } + return params.ReflectType(pkgPath, lookupType, opts...) +} + +// OutputReflectType builds the Output struct reflect.Type using state.Parameters.ReflectType. +func (c *Component) OutputReflectType(pkgPath string, lookupType xreflect.LookupType, opts ...state.ReflectOption) (reflect.Type, error) { + params := c.OutputParameters() + if len(params) == 0 { + return nil, nil + } + return params.ReflectType(pkgPath, lookupType, opts...) +} + // ComponentFrom extracts the typed component from a ComponentArtifact. // Returns (nil, false) when a is nil or contains an unexpected concrete type. func ComponentFrom(a *shape.ComponentArtifact) (*Component, bool) { diff --git a/repository/shape/load/model_test.go b/repository/shape/load/model_test.go new file mode 100644 index 000000000..a1c13bb5a --- /dev/null +++ b/repository/shape/load/model_test.go @@ -0,0 +1,94 @@ +package load + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view/state" + "github.com/viant/xreflect" +) + +type packageScopedPatchFoos struct { + ID int + Name *string + Quantity *int +} + +func TestComponent_InputReflectType_UsesComponentPackageForPatchHelpers(t *testing.T) { + const pkgPath = "github.com/viant/datly/e2e/v1/shape/dev/generate_patch_basic_one" + + types := xreflect.NewTypes() + require.NoError(t, types.Register("Foos", + xreflect.WithPackage(pkgPath), + xreflect.WithReflectType(reflect.TypeOf(packageScopedPatchFoos{})), + )) + + component := &Component{ + RootView: "Foos", + TypeContext: &typectx.Context{ + PackagePath: pkgPath, + PackageName: "generate_patch_basic_one", + }, + Input: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + Schema: &state.Schema{Name: "Foos"}, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoosId", + In: state.NewParameterLocation("Foos"), + Tag: `codec:"structql,uri=foos/cur_foos_id.sql"`, + Schema: state.NewSchema(reflect.TypeOf(&struct { + Values []int + }{})), + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoos", + In: state.NewViewLocation("CurFoos"), + Tag: `view:"CurFoos" sql:"uri=foos/cur_foos.sql"`, + Schema: &state.Schema{Name: "Foos", Cardinality: state.Many}, + }, + }, + }, + } + + rType, err := component.InputReflectType(pkgPath, types.Lookup, state.WithSetMarker(), state.WithTypeName("Input")) + require.NoError(t, err) + require.NotNil(t, rType) + + foosField, ok := rType.FieldByName("Foos") + require.True(t, ok) + assert.Equal(t, "packageScopedPatchFoos", namedType(foosField.Type).Name()) + + curFoosField, ok := rType.FieldByName("CurFoos") + require.True(t, ok) + assert.Equal(t, reflect.Slice, curFoosField.Type.Kind()) + assert.Equal(t, "packageScopedPatchFoos", namedType(curFoosField.Type.Elem()).Name()) + + curFoosIDField, ok := rType.FieldByName("CurFoosId") + require.True(t, ok) + assert.Equal(t, reflect.Ptr, curFoosIDField.Type.Kind()) + assert.Equal(t, reflect.Struct, curFoosIDField.Type.Elem().Kind()) + valuesField, ok := curFoosIDField.Type.Elem().FieldByName("Values") + require.True(t, ok) + assert.Equal(t, reflect.Slice, valuesField.Type.Kind()) + assert.Equal(t, reflect.Int, valuesField.Type.Elem().Kind()) +} + +func namedType(rType reflect.Type) reflect.Type { + for rType != nil && (rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice) { + rType = rType.Elem() + } + return rType +} diff --git a/repository/shape/model.go b/repository/shape/model.go index 88c8da537..8f5f97335 100644 --- a/repository/shape/model.go +++ b/repository/shape/model.go @@ -1,6 +1,8 @@ package shape import ( + "os" + "path/filepath" "reflect" "github.com/viant/datly/view" @@ -28,6 +30,26 @@ type Source struct { DQL string } +func (s *Source) BaseDir() string { + if s == nil { + return "" + } + location := filepath.Clean(s.Path) + if location == "" || location == "." { + return "" + } + if info, err := os.Stat(location); err == nil { + if info.IsDir() { + return location + } + return filepath.Dir(location) + } + if ext := filepath.Ext(location); ext != "" { + return filepath.Dir(location) + } + return location +} + // ScanSpec is implemented by every scan-pipeline descriptor result. // The sole production implementation is *scan.Result. type ScanSpec interface { @@ -67,6 +89,11 @@ type ViewArtifacts struct { Views view.Views } +// ResourceArtifacts is the runtime resource payload produced by Loader. +type ResourceArtifacts struct { + Resource *view.Resource +} + // ComponentArtifact is the runtime component payload produced by Loader. type ComponentArtifact struct { Resource *view.Resource diff --git a/repository/shape/options.go b/repository/shape/options.go index 27b970fae..dbd0529d8 100644 --- a/repository/shape/options.go +++ b/repository/shape/options.go @@ -238,3 +238,13 @@ func WithInferTypeContextDefaults(enabled bool) CompileOption { o.InferTypeContext = &enabled } } + +// WithLinkedTypes enables/disables linked Go type support during compile. +func WithLinkedTypes(enabled bool) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.UseLinkedTypes = &enabled + } +} diff --git a/repository/shape/parity_test.go b/repository/shape/parity_test.go index 8041328b1..4703d7d0d 100644 --- a/repository/shape/parity_test.go +++ b/repository/shape/parity_test.go @@ -75,6 +75,24 @@ func TestEngineParity_StructPipeline(t *testing.T) { assert.Equal(t, reflect.TypeOf(mv.Schema.CompType()), reflect.TypeOf(ev.Schema.CompType())) } +func TestEngineParity_LoadResource(t *testing.T) { + source := &paritySource{} + engine := shape.New( + shape.WithName("/v1/api/parity"), + shape.WithScanner(shapeScan.New()), + shape.WithPlanner(shapePlan.New()), + shape.WithLoader(shapeLoad.New()), + ) + + artifact, err := engine.LoadResource(context.Background(), source) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + require.Len(t, artifact.Resource.Views, 1) + assert.Equal(t, "rows", artifact.Resource.Views[0].Name) + assert.Equal(t, "REPORT", artifact.Resource.Views[0].Table) +} + func TestEngineParity_Component_SourceTagFieldJoin(t *testing.T) { source := &parityJoinSource{} scanner := shapeScan.New() diff --git a/repository/shape/plan/model.go b/repository/shape/plan/model.go index ffb295f28..38eaca97b 100644 --- a/repository/shape/plan/model.go +++ b/repository/shape/plan/model.go @@ -20,13 +20,35 @@ type Result struct { Views []*View ViewsByName map[string]*View States []*State + Components []*ComponentRoute Types []*Type ColumnsDiscovery bool + Const map[string]string TypeContext *typectx.Context Directives *dqlshape.Directives Diagnostics []*dqlshape.Diagnostic } +type ComponentRoute struct { + Path string + FieldName string + Type reflect.Type + InputType reflect.Type + OutputType reflect.Type + InputName string + OutputName string + ViewName string + SourceURL string + SummaryURL string + Name string + RoutePath string + Method string + Connector string + Marshaller string + Handler string + Report *dqlshape.ReportDirective +} + // Type is normalized type metadata collected during compile. type Type struct { Name string @@ -61,14 +83,26 @@ type View struct { SQL string SQLURI string Summary string + SummaryURL string + SummaryName string Relations []*Relation Holder string - AllowNulls *bool - SelectorNamespace string - SelectorNoLimit *bool - SchemaType string - ColumnsDiscovery bool + AllowNulls *bool + Groupable *bool + SelectorNamespace string + SelectorLimit *int + SelectorNoLimit *bool + SelectorCriteria *bool + SelectorProjection *bool + SelectorOrderBy *bool + SelectorOffset *bool + SelectorPage *bool + SelectorFilterable []string + SelectorOrderByColumns map[string]string + SchemaType string + ColumnsDiscovery bool + Self *SelfReference Cardinality string ElementType reflect.Type @@ -79,6 +113,8 @@ type View struct { // ViewDeclaration captures declaration options used to derive a view from DQL directives. type ViewDeclaration struct { Tag string + TypeName string + Dest string Codec string CodecArgs []string HandlerName string @@ -97,6 +133,7 @@ type ViewDeclaration struct { Async bool Output bool Predicates []*ViewPredicate + ColumnsConfig map[string]*ViewColumnConfig } // ViewPredicate captures WithPredicate / EnsurePredicate metadata. @@ -107,16 +144,25 @@ type ViewPredicate struct { Arguments []string } +// ViewColumnConfig captures declaration-level per-column overrides. +type ViewColumnConfig struct { + DataType string + Tag string + Groupable *bool +} + // Relation is normalized relation metadata extracted from DQL joins. type Relation struct { - Name string - Holder string - Ref string - Table string - Kind string - Raw string - On []*RelationLink - Warnings []string + Name string + Parent string + Holder string + Ref string + Table string + Kind string + Raw string + ColumnsConfig map[string]*ViewColumnConfig + On []*RelationLink + Warnings []string } // RelationLink represents one parent/ref join predicate. @@ -130,11 +176,19 @@ type RelationLink struct { Expression string } +// SelfReference captures self-join tree metadata parsed from DQL. +type SelfReference struct { + Holder string + Child string + Parent string +} + // State is a normalized parameter field plan. type State struct { state.Parameter `yaml:",inline"` QuerySelector string OutputDataType string + EmitOutput bool } func (s *State) KindString() string { diff --git a/repository/shape/plan/planner.go b/repository/shape/plan/planner.go index f78b4bcc1..2340c185c 100644 --- a/repository/shape/plan/planner.go +++ b/repository/shape/plan/planner.go @@ -10,6 +10,7 @@ import ( metakeys "github.com/viant/datly/repository/locator/meta/keys" outputkeys "github.com/viant/datly/repository/locator/output/keys" "github.com/viant/datly/repository/shape" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/scan" "github.com/viant/datly/view/state" ) @@ -61,11 +62,16 @@ func (p *Planner) Plan(ctx context.Context, scanned *shape.ScanResult, _ ...shap result.ViewsByName[v.Name] = v } } + assignNestedRelationParents(result.Views) for _, item := range scanResult.StateFields { result.States = append(result.States, normalizeState(item)) } + for _, item := range scanResult.ComponentFields { + result.Components = append(result.Components, normalizeComponent(item)) + } + return &shape.PlanResult{Source: scanned.Source, Plan: result}, nil } @@ -85,10 +91,43 @@ func normalizeView(field *scan.Field) *View { result.Partitioner = tag.View.PartitionerType result.PartitionedConcurrency = tag.View.PartitionedConcurrency result.RelationalConcurrency = tag.View.RelationalConcurrency + result.Groupable = tag.View.Groupable + result.SelectorNamespace = strings.TrimSpace(tag.View.SelectorNamespace) + result.SelectorLimit = tag.View.Limit + if tag.View.Limit != nil { + noLimit := *tag.View.Limit == 0 + result.SelectorNoLimit = &noLimit + } + result.SelectorCriteria = tag.View.SelectorCriteria + result.SelectorProjection = tag.View.SelectorProjection + result.SelectorOrderBy = tag.View.SelectorOrderBy + result.SelectorOffset = tag.View.SelectorOffset + result.SelectorPage = tag.View.SelectorPage + if len(tag.View.SelectorFilterable) > 0 { + result.SelectorFilterable = append([]string(nil), tag.View.SelectorFilterable...) + } + if len(tag.View.SelectorOrderByColumns) > 0 { + result.SelectorOrderByColumns = map[string]string{} + for key, value := range tag.View.SelectorOrderByColumns { + result.SelectorOrderByColumns[key] = value + } + } + if strings.TrimSpace(tag.View.CustomTag) != "" || strings.TrimSpace(field.ViewTypeName) != "" || strings.TrimSpace(field.ViewDest) != "" { + result.Declaration = &ViewDeclaration{ + Tag: strings.TrimSpace(tag.View.CustomTag), + TypeName: strings.TrimSpace(field.ViewTypeName), + Dest: strings.TrimSpace(field.ViewDest), + } + } } result.SQL = tag.SQL.SQL result.SQLURI = tag.SQL.URI result.Summary = tag.SummarySQL.SQL + if tag.View != nil && strings.TrimSpace(tag.View.SummaryURI) != "" { + result.SummaryURL = strings.TrimSpace(tag.View.SummaryURI) + } else { + result.SummaryURL = tag.SummarySQL.URI + } if len(tag.LinkOn) > 0 { result.Relations = append(result.Relations, relationFromTagLinks(field.Name, tag.LinkOn)) } @@ -172,13 +211,18 @@ func normalizeState(field *scan.Field) *State { Name: field.Name, In: &state.Location{}, }, + QuerySelector: strings.TrimSpace(field.QuerySelector), } - if field.StateTag == nil || field.StateTag.Parameter == nil { + if field.StateTag == nil { result.Schema = state.NewSchema(field.Type) return result } pTag := field.StateTag.Parameter + if pTag == nil { + result.Schema = state.NewSchema(field.Type) + return result + } result.Name = firstNonEmpty(pTag.Name, field.Name) result.In = &state.Location{ Kind: state.Kind(strings.ToLower(strings.TrimSpace(pTag.Kind))), @@ -193,14 +237,133 @@ func normalizeState(field *scan.Field) *State { result.URI = pTag.URI result.ErrorStatusCode = pTag.ErrorCode result.ErrorMessage = pTag.ErrorMessage - result.Schema = state.NewSchema(resolveStateType(result, field.Type)) + if typeName := strings.TrimSpace(field.StateTag.TypeName); typeName != "" { + applyStateTypeName(result.Schema, typeName) + } + state.BuildCodec(field.StateTag, &result.Parameter) + state.BuildHandler(field.StateTag, &result.Parameter) + if value, err := field.StateTag.GetValue(result.Schema.Type()); err == nil && value != nil { + result.Value = normalizeStateValue(value) + } if dataType := strings.TrimSpace(pTag.DataType); dataType != "" { result.Schema.DataType = dataType } return result } +func normalizeStateValue(value interface{}) interface{} { + switch actual := value.(type) { + case *string: + if actual == nil { + return nil + } + return *actual + } + return value +} + +func applyStateTypeName(schema *state.Schema, typeName string) { + if schema == nil { + return + } + typeName = strings.TrimSpace(strings.TrimPrefix(typeName, "*")) + if typeName == "" { + return + } + if idx := strings.LastIndex(typeName, "."); idx != -1 { + schema.Package = strings.TrimSpace(typeName[:idx]) + schema.PackagePath = schema.Package + schema.Name = strings.TrimSpace(typeName[idx+1:]) + return + } + schema.Name = typeName +} + +func normalizeComponent(field *scan.Field) *ComponentRoute { + result := &ComponentRoute{ + Path: field.Path, + FieldName: field.Name, + Type: field.Type, + InputType: field.ComponentInputType, + OutputType: field.ComponentOutputType, + InputName: field.ComponentInputName, + OutputName: field.ComponentOutputName, + Name: field.Name, + } + if field.ComponentTag != nil && field.ComponentTag.Component != nil { + tag := field.ComponentTag.Component + if strings.TrimSpace(tag.Name) != "" { + result.Name = strings.TrimSpace(tag.Name) + } + result.RoutePath = strings.TrimSpace(tag.Path) + result.Method = strings.TrimSpace(tag.Method) + result.Connector = strings.TrimSpace(tag.Connector) + result.Marshaller = strings.TrimSpace(tag.Marshaller) + result.Handler = strings.TrimSpace(tag.Handler) + result.ViewName = strings.TrimSpace(tag.View) + result.SourceURL = strings.TrimSpace(tag.Source) + result.SummaryURL = strings.TrimSpace(tag.Summary) + if tag.Report || tag.ReportInput != "" { + result.Report = &dqlshape.ReportDirective{ + Enabled: tag.Report, + Input: strings.TrimSpace(tag.ReportInput), + Dimensions: strings.TrimSpace(tag.ReportDimensions), + Measures: strings.TrimSpace(tag.ReportMeasures), + Filters: strings.TrimSpace(tag.ReportFilters), + OrderBy: strings.TrimSpace(tag.ReportOrderBy), + Limit: strings.TrimSpace(tag.ReportLimit), + Offset: strings.TrimSpace(tag.ReportOffset), + } + } + } + return result +} + +func assignNestedRelationParents(views []*View) { + if len(views) == 0 { + return + } + byPath := map[string]*View{} + for _, item := range views { + if item == nil || strings.TrimSpace(item.Path) == "" { + continue + } + byPath[strings.TrimSpace(item.Path)] = item + } + for _, item := range views { + if item == nil || len(item.Relations) == 0 { + continue + } + parentPath := parentViewPath(item.Path) + if parentPath == "" { + continue + } + parent := byPath[parentPath] + if parent == nil || strings.TrimSpace(parent.Name) == "" { + continue + } + for _, rel := range item.Relations { + if rel == nil || strings.TrimSpace(rel.Parent) != "" { + continue + } + rel.Parent = parent.Name + } + } +} + +func parentViewPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + index := strings.LastIndex(path, ".") + if index == -1 { + return "" + } + return strings.TrimSpace(path[:index]) +} + func resolveStateType(item *State, fallback reflect.Type) reflect.Type { if item.In == nil { return fallback diff --git a/repository/shape/plan/planner_test.go b/repository/shape/plan/planner_test.go index dc0416eb8..483498636 100644 --- a/repository/shape/plan/planner_test.go +++ b/repository/shape/plan/planner_test.go @@ -3,6 +3,7 @@ package plan import ( "context" "embed" + "reflect" "strings" "testing" @@ -13,6 +14,9 @@ import ( outputkeys "github.com/viant/datly/repository/locator/output/keys" "github.com/viant/datly/repository/shape" "github.com/viant/datly/repository/shape/scan" + "github.com/viant/datly/view/tags" + "github.com/viant/x" + "github.com/viant/xdatly" ) //go:embed testdata/*.sql @@ -35,6 +39,7 @@ type reportSource struct { Job interface{} `parameter:"job,kind=async,in=job"` VName interface{} `parameter:"viewName,kind=meta,in=view.name"` ID int `parameter:"id,kind=query,in=id"` + Route struct{} `component:",path=/v1/api/dev/report,method=GET,connector=dev"` } type relationRow struct { @@ -49,6 +54,68 @@ type relationSourceWithFields struct { Rows []relationRow `view:"rows,table=REPORT" on:"ReportID:rows.report_id=ID:report.id"` } +type viewTypeDestSource struct { + Rows []relationRow `view:"rows,table=REPORT,type=ReportRow,dest=rows.go"` +} + +type typedRouteInput struct { + ID int +} + +type typedRouteOutput struct { + Data []reportRow +} + +type typedRouteSource struct { + Route xdatly.Component[typedRouteInput, typedRouteOutput] `component:",path=/v1/api/dev/report,method=GET"` +} + +type reportRouteSource struct { + Route xdatly.Component[typedRouteInput, typedRouteOutput] `component:",path=/v1/api/dev/report,method=GET,report=true,reportInput=NamedReportInput,reportDimensions=Dims,reportMeasures=Metrics,reportFilters=Predicates,reportOrderBy=Sort,reportLimit=Take,reportOffset=Skip"` +} + +type dynamicRouteInput struct { + Name string +} + +type dynamicRouteOutput struct { + Count int +} + +type namedDynamicRouteInput struct { + Name string `parameter:"name,kind=query,in=name"` +} + +type namedDynamicRouteOutput struct { + Count int `parameter:"count,kind=output,in=view"` +} + +type taggedComponentStateSource struct { + Auth interface{} `parameter:",kind=component,in=GET:/v1/api/dev/auth" typeName:"github.com/acme/auth.UserAclOutput"` +} + +type constStateSource struct { + Product string `parameter:",kind=const,in=Product" value:"PRODUCT" internal:"true"` +} + +type codecStateSource struct { + Jwt string `parameter:",kind=header,in=Authorization,errorCode=401" codec:"JwtClaim"` + Run string `parameter:",kind=body,in=run" handler:"Exec"` +} + +type selectorHolderSource struct { + Route xdatly.Component[typedRouteInput, typedRouteOutput] `component:",path=/v1/api/dev/report,method=GET"` + ViewSelect struct { + Fields []string `parameter:"fields,kind=query,in=_fields"` + Page int `parameter:"page,kind=query,in=_page"` + } `querySelector:"rows"` +} + +type summaryViewSource struct { + embeddedFS + Rows []relationRow `view:"rows,table=REPORT,summaryURI=testdata/report_summary.sql" sql:"uri=testdata/report.sql"` +} + func TestPlanner_Plan(t *testing.T) { scanner := scan.New() scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) @@ -88,6 +155,45 @@ func TestPlanner_Plan(t *testing.T) { require.NotNil(t, stateByPath["id"]) assert.Equal(t, "query", stateByPath["id"].KindString()) assert.Equal(t, "id", stateByPath["id"].InName()) + require.Len(t, result.Components, 1) + assert.Equal(t, "Route", result.Components[0].FieldName) + assert.Equal(t, "/v1/api/dev/report", result.Components[0].RoutePath) + assert.Equal(t, "GET", result.Components[0].Method) + assert.Equal(t, "dev", result.Components[0].Connector) +} + +func TestPlanner_Plan_QuerySelectorHolder(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &selectorHolderSource{}}) + require.NoError(t, err) + + planned, err := New().Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + + byName := map[string]*State{} + for _, item := range result.States { + byName[item.Name] = item + } + require.NotNil(t, byName["fields"]) + assert.Equal(t, "rows", byName["fields"].QuerySelector) + require.NotNil(t, byName["page"]) + assert.Equal(t, "rows", byName["page"].QuerySelector) +} + +func TestPlanner_Plan_ViewSummaryURI(t *testing.T) { + scanned, err := scan.New().Scan(context.Background(), &shape.Source{Struct: &summaryViewSource{}}) + require.NoError(t, err) + + planned, err := New().Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Views, 1) + assert.Equal(t, "testdata/report_summary.sql", result.Views[0].SummaryURL) } func TestPlanner_Plan_LinkOnProducesStructuredRelations(t *testing.T) { @@ -113,6 +219,212 @@ func TestPlanner_Plan_LinkOnProducesStructuredRelations(t *testing.T) { assert.Equal(t, "id", relation.On[0].RefColumn) } +func TestPlanner_Plan_ComponentHolderTypes(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &typedRouteSource{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Components, 1) + assert.Equal(t, reflect.TypeOf(typedRouteInput{}), result.Components[0].InputType) + assert.Equal(t, reflect.TypeOf(typedRouteOutput{}), result.Components[0].OutputType) + assert.Empty(t, result.Components[0].InputName) + assert.Empty(t, result.Components[0].OutputName) +} + +func TestPlanner_Plan_ComponentReportTags(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportRouteSource{}}) + require.NoError(t, err) + + planned, err := New().Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Components, 1) + require.NotNil(t, result.Components[0].Report) + assert.True(t, result.Components[0].Report.Enabled) + assert.Equal(t, "NamedReportInput", result.Components[0].Report.Input) + assert.Equal(t, "Dims", result.Components[0].Report.Dimensions) + assert.Equal(t, "Metrics", result.Components[0].Report.Measures) + assert.Equal(t, "Predicates", result.Components[0].Report.Filters) + assert.Equal(t, "Sort", result.Components[0].Report.OrderBy) + assert.Equal(t, "Take", result.Components[0].Report.Limit) + assert.Equal(t, "Skip", result.Components[0].Report.Offset) +} + +func TestPlanner_Plan_DynamicComponentHolderTypes(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET"` + }{ + Route: xdatly.Component[any, any]{ + Inout: dynamicRouteInput{}, + Output: dynamicRouteOutput{}, + }, + }}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Components, 1) + assert.Equal(t, reflect.TypeOf(dynamicRouteInput{}), result.Components[0].InputType) + assert.Equal(t, reflect.TypeOf(dynamicRouteOutput{}), result.Components[0].OutputType) +} + +func TestPlanner_Plan_DynamicComponentHolderExplicitNames(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET,input=ReportInput,output=ReportOutput"` + }{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Components, 1) + assert.Nil(t, result.Components[0].InputType) + assert.Nil(t, result.Components[0].OutputType) + assert.Equal(t, "ReportInput", result.Components[0].InputName) + assert.Equal(t, "ReportOutput", result.Components[0].OutputName) +} + +func TestPlanner_Plan_PreservesConstValueTag(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &constStateSource{}}) + require.NoError(t, err) + + planned, err := New().Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.States, 1) + require.NotNil(t, result.States[0].In) + assert.Equal(t, "const", result.States[0].KindString()) + assert.Equal(t, "Product", result.States[0].InName()) + assert.Equal(t, "PRODUCT", result.States[0].Value) +} + +func TestPlanner_Plan_PreservesCodecAndHandlerTags(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &codecStateSource{}}) + require.NoError(t, err) + + planned, err := New().Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.States, 2) + stateByName := map[string]*State{} + for _, item := range result.States { + stateByName[item.Name] = item + } + require.NotNil(t, stateByName["Jwt"]) + require.NotNil(t, stateByName["Run"]) + if stateByName["Jwt"].Output == nil || stateByName["Jwt"].Output.Name != "JwtClaim" { + t.Fatalf("expected Jwt codec to be preserved, got %#v", stateByName["Jwt"].Output) + } + if stateByName["Run"].Handler == nil || stateByName["Run"].Handler.Name != "Exec" { + t.Fatalf("expected Run handler to be preserved, got %#v", stateByName["Run"].Handler) + } +} + +func TestPlanner_Plan_DynamicComponentHolderExplicitNamesFromRegistry(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(namedDynamicRouteInput{}), x.WithPkgPath("github.com/viant/datly/repository/shape/plan"), x.WithName("ReportInput"))) + registry.Register(x.NewType(reflect.TypeOf(namedDynamicRouteOutput{}), x.WithPkgPath("github.com/viant/datly/repository/shape/plan"), x.WithName("ReportOutput"))) + + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{ + Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET,input=ReportInput,output=ReportOutput"` + }{}, + TypeRegistry: registry, + }) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Components, 1) + assert.Equal(t, reflect.TypeOf(namedDynamicRouteInput{}), result.Components[0].InputType) + assert.Equal(t, reflect.TypeOf(namedDynamicRouteOutput{}), result.Components[0].OutputType) + assert.Equal(t, "ReportInput", result.Components[0].InputName) + assert.Equal(t, "ReportOutput", result.Components[0].OutputName) +} + +func TestPlanner_Plan_StateTypeNameOverridesInterfaceType(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &taggedComponentStateSource{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.States, 1) + assert.Equal(t, "github.com/acme/auth", result.States[0].Schema.Package) + assert.Equal(t, "github.com/acme/auth", result.States[0].Schema.PackagePath) + assert.Equal(t, "UserAclOutput", result.States[0].Schema.Name) +} + +func TestPlanner_Plan_AssignsNestedRelationParent(t *testing.T) { + scanned := &shape.ScanResult{ + Source: &shape.Source{Name: "nested"}, + Descriptors: &scan.Result{ + RootType: reflect.TypeOf(struct{}{}), + ViewFields: []*scan.Field{ + { + Path: "Route.Output.Data", + Name: "Data", + Type: reflect.TypeOf([]struct{}{}), + ViewTag: &tags.Tag{ + View: &tags.View{Name: "vendor"}, + SQL: tags.NewViewSQL("", "vendor.sql"), + }, + }, + { + Path: "Route.Output.Data.Products", + Name: "Products", + Type: reflect.TypeOf([]struct{}{}), + ViewTag: &tags.Tag{ + View: &tags.View{Table: "PRODUCT"}, + SQL: tags.NewViewSQL("", "products.sql"), + LinkOn: []string{"Id:ID=VendorId:VENDOR_ID"}, + }, + }, + }, + }, + } + result, err := New().Plan(context.Background(), scanned) + require.NoError(t, err) + planned, ok := result.Plan.(*Result) + require.True(t, ok) + require.Len(t, planned.Views, 2) + require.Len(t, planned.Views[1].Relations, 1) + require.Equal(t, "vendor", planned.Views[1].Relations[0].Parent) +} + func TestPlanner_Plan_LinkOnPreservesFieldSelectors(t *testing.T) { scanner := scan.New() scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &relationSourceWithFields{}}) @@ -138,6 +450,24 @@ func TestPlanner_Plan_LinkOnPreservesFieldSelectors(t *testing.T) { assert.Equal(t, "id", relation.On[0].RefColumn) } +func TestPlanner_Plan_ViewTypeDestDeclaration(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &viewTypeDestSource{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + result, ok := ResultFrom(planned) + require.True(t, ok) + require.Len(t, result.Views, 1) + viewPlan := result.Views[0] + require.NotNil(t, viewPlan.Declaration) + assert.Equal(t, "ReportRow", viewPlan.Declaration.TypeName) + assert.Equal(t, "rows.go", viewPlan.Declaration.Dest) +} + // stubScanSpec is a non-scan-Result implementation of shape.ScanSpec used to // verify that Plan() returns an error when given an unexpected descriptor type. type stubScanSpec struct{} diff --git a/repository/shape/plan/testdata/report_summary.sql b/repository/shape/plan/testdata/report_summary.sql new file mode 100644 index 000000000..3c601cdc2 --- /dev/null +++ b/repository/shape/plan/testdata/report_summary.sql @@ -0,0 +1 @@ +SELECT COUNT(*) AS CNT FROM REPORT diff --git a/repository/shape/scan/component_contract.go b/repository/shape/scan/component_contract.go new file mode 100644 index 000000000..c26cc3a78 --- /dev/null +++ b/repository/shape/scan/component_contract.go @@ -0,0 +1,229 @@ +package scan + +import ( + "fmt" + "path" + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/componenttag" + "github.com/viant/datly/repository/shape/typectx" +) + +const ( + xdatlyComponentPkg = "github.com/viant/xdatly" + xdatlyComponentName = "Component[" +) + +type componentContract struct { + InputType reflect.Type + OutputType reflect.Type + InputName string + OutputName string + UsesDynamic bool +} + +func resolveComponentContract(source *shape.Source, fieldType reflect.Type, value reflect.Value, tag *componenttag.Tag) (*componentContract, error) { + contract := &componentContract{} + if tag != nil && tag.Component != nil { + contract.InputName = strings.TrimSpace(tag.Component.Input) + contract.OutputName = strings.TrimSpace(tag.Component.Output) + } + if provider := componentContractProvider(value); provider != nil { + contract.InputType = provider.ComponentInputType() + contract.OutputType = provider.ComponentOutputType() + } + typedInput, typedOutput, dynamic := inspectXDatlyComponent(fieldType, value) + isStructuredHolder := providerDefined(contract) || isXDatlyComponentType(fieldType) + if contract.InputType == nil { + contract.InputType = typedInput + } + if contract.OutputType == nil { + contract.OutputType = typedOutput + } + if contract.InputType == nil && contract.InputName != "" { + contract.InputType = resolveNamedComponentContractType(source, contract.InputName) + } + if contract.OutputType == nil && contract.OutputName != "" { + contract.OutputType = resolveNamedComponentContractType(source, contract.OutputName) + } + contract.UsesDynamic = dynamic + if !isStructuredHolder && contract.InputName == "" && contract.OutputName == "" { + return contract, nil + } + if contract.UsesDynamic && + contract.InputType == nil && + contract.OutputType == nil && + contract.InputName == "" && + contract.OutputName == "" { + return nil, fmt.Errorf("dynamic component holder requires explicit input/output tag names or initialized Inout/Output values") + } + if contract.InputType == nil && contract.InputName == "" { + return nil, fmt.Errorf("component input contract type is unresolved") + } + if contract.OutputType == nil && contract.OutputName == "" { + return nil, fmt.Errorf("component output contract type is unresolved") + } + return contract, nil +} + +func resolveNamedComponentContractType(source *shape.Source, typeName string) reflect.Type { + typeName = strings.TrimSpace(typeName) + if source == nil || typeName == "" { + return nil + } + registry := source.EnsureTypeRegistry() + if registry == nil { + return nil + } + resolver := typectx.NewResolver(registry, componentTypeContext(source)) + resolved, err := resolver.Resolve(typeName) + if err != nil || resolved == "" { + return nil + } + lookup := registry.Lookup(resolved) + if lookup == nil || lookup.Type == nil { + return nil + } + return unwrapComponentType(lookup.Type) +} + +func componentTypeContext(source *shape.Source) *typectx.Context { + if source == nil { + return nil + } + rootType, err := source.ResolveRootType() + if err != nil || rootType == nil { + return nil + } + pkgPath := strings.TrimSpace(rootType.PkgPath()) + if pkgPath == "" { + return nil + } + return &typectx.Context{ + DefaultPackage: pkgPath, + PackagePath: pkgPath, + PackageName: path.Base(pkgPath), + } +} + +func providerDefined(contract *componentContract) bool { + return contract != nil && (contract.InputType != nil || contract.OutputType != nil) +} + +type typedComponentContract interface { + ComponentInputType() reflect.Type + ComponentOutputType() reflect.Type +} + +func componentContractProvider(value reflect.Value) typedComponentContract { + if !value.IsValid() { + return nil + } + if value.CanInterface() { + if provider, ok := value.Interface().(typedComponentContract); ok { + return provider + } + } + for value.IsValid() && value.Kind() == reflect.Interface { + if value.IsNil() { + return nil + } + value = value.Elem() + } + if value.CanInterface() { + if provider, ok := value.Interface().(typedComponentContract); ok { + return provider + } + } + for value.IsValid() && value.Kind() == reflect.Ptr { + if value.IsNil() { + return nil + } + value = value.Elem() + if !value.CanInterface() { + continue + } + if provider, ok := value.Interface().(typedComponentContract); ok { + return provider + } + } + return nil +} + +func componentFieldValue(holder reflect.Value, fieldName string) reflect.Value { + if !holder.IsValid() { + return reflect.Value{} + } + for holder.IsValid() && holder.Kind() == reflect.Ptr { + if holder.IsNil() { + return reflect.Value{} + } + holder = holder.Elem() + } + if !holder.IsValid() || holder.Kind() != reflect.Struct { + return reflect.Value{} + } + field := holder.FieldByName(fieldName) + if !field.IsValid() { + return reflect.Value{} + } + return field +} + +func concreteComponentFieldType(fallback reflect.Type, holder reflect.Value, fieldName string) (reflect.Type, bool) { + if fallback != nil && !(fallback.Kind() == reflect.Interface && fallback.NumMethod() == 0) { + return fallback, false + } + field := componentFieldValue(holder, fieldName) + if !field.IsValid() { + return nil, true + } + for field.IsValid() && field.Kind() == reflect.Interface { + if field.IsNil() { + return nil, true + } + field = field.Elem() + } + for field.IsValid() && field.Kind() == reflect.Ptr { + if field.IsNil() { + return nil, true + } + field = field.Elem() + } + if !field.IsValid() { + return nil, true + } + return field.Type(), true +} + +func inspectXDatlyComponent(rType reflect.Type, value reflect.Value) (reflect.Type, reflect.Type, bool) { + rType = unwrapComponentType(rType) + if !isXDatlyComponentType(rType) { + return nil, nil, false + } + inoutField, ok := rType.FieldByName("Inout") + if !ok { + return nil, nil, false + } + outputField, ok := rType.FieldByName("Output") + if !ok { + return nil, nil, false + } + inputType, inputDynamic := concreteComponentFieldType(inoutField.Type, value, "Inout") + outputType, outputDynamic := concreteComponentFieldType(outputField.Type, value, "Output") + return inputType, outputType, inputDynamic || outputDynamic +} + +func isXDatlyComponentType(rType reflect.Type) bool { + rType = unwrapComponentType(rType) + return rType != nil && rType.Kind() == reflect.Struct && rType.PkgPath() == xdatlyComponentPkg && strings.HasPrefix(rType.Name(), xdatlyComponentName) +} + +func unwrapComponentType(rType reflect.Type) reflect.Type { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + return rType +} diff --git a/repository/shape/scan/model.go b/repository/shape/scan/model.go index 357299250..59e968a51 100644 --- a/repository/shape/scan/model.go +++ b/repository/shape/scan/model.go @@ -4,30 +4,41 @@ import ( "embed" "reflect" + "github.com/viant/datly/repository/shape/componenttag" "github.com/viant/datly/view/tags" ) // Result holds scan output produced from a struct source. type Result struct { - RootType reflect.Type - EmbedFS *embed.FS - Fields []*Field - ByPath map[string]*Field - ViewFields []*Field - StateFields []*Field + RootType reflect.Type + EmbedFS *embed.FS + Fields []*Field + ByPath map[string]*Field + ViewFields []*Field + StateFields []*Field + ComponentFields []*Field } // Field describes one scanned struct field. type Field struct { - Path string - Name string - Index []int - Type reflect.Type - Tag reflect.StructTag - Anonymous bool + Path string + Name string + Index []int + Type reflect.Type + QuerySelector string + ComponentInputType reflect.Type + ComponentOutputType reflect.Type + ComponentInputName string + ComponentOutputName string + Tag reflect.StructTag + Anonymous bool + ViewTypeName string + ViewDest string - HasViewTag bool - HasStateTag bool - ViewTag *tags.Tag - StateTag *tags.Tag + HasViewTag bool + HasStateTag bool + HasComponentTag bool + ViewTag *tags.Tag + StateTag *tags.Tag + ComponentTag *componenttag.Tag } diff --git a/repository/shape/scan/scanner.go b/repository/shape/scan/scanner.go index 255f9cd4b..c4463380f 100644 --- a/repository/shape/scan/scanner.go +++ b/repository/shape/scan/scanner.go @@ -2,13 +2,19 @@ package scan import ( "context" + "embed" "fmt" + "os" + "path/filepath" "reflect" "strings" + afsembed "github.com/viant/afs/embed" "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/componenttag" "github.com/viant/datly/view/state" "github.com/viant/datly/view/tags" + taglytags "github.com/viant/tagly/tags" ) // StructScanner scans arbitrary struct types and extracts Datly-relevant tags. @@ -35,13 +41,15 @@ func (s *StructScanner) Scan(ctx context.Context, source *shape.Source, _ ...sha } embedder := resolveEmbedder(source) + baseDir := source.BaseDir() + rootValue := resolveRootValue(source) result := &Result{ RootType: root, EmbedFS: embedder.EmbedFS(), ByPath: map[string]*Field{}, } - if err = s.scanStruct(root, "", nil, embedder, result, map[reflect.Type]bool{}); err != nil { + if err = s.scanStruct(source, root, rootValue, "", nil, "", embedder, baseDir, result, map[reflect.Type]bool{}); err != nil { return nil, err } @@ -85,11 +93,32 @@ func resolveEmbedder(source *shape.Source) *state.FSEmbedder { return embedder } +func resolveRootValue(source *shape.Source) reflect.Value { + if source == nil || source.Struct == nil { + return reflect.Value{} + } + value := reflect.ValueOf(source.Struct) + for value.IsValid() && value.Kind() == reflect.Ptr { + if value.IsNil() { + return reflect.Value{} + } + value = value.Elem() + } + if !value.IsValid() || value.Kind() != reflect.Struct { + return reflect.Value{} + } + return value +} + func (s *StructScanner) scanStruct( + source *shape.Source, rType reflect.Type, + rootValue reflect.Value, prefix string, indexPrefix []int, + inheritedQuerySelector string, embedder *state.FSEmbedder, + baseDir string, result *Result, visited map[reflect.Type]bool, ) error { @@ -108,16 +137,22 @@ func (s *StructScanner) scanStruct( combinedIndex := append(append([]int{}, indexPrefix...), field.Index...) descriptor := &Field{ - Path: path, - Name: field.Name, - Index: combinedIndex, - Type: field.Type, - Tag: field.Tag, - Anonymous: field.Anonymous, + Path: path, + Name: field.Name, + Index: combinedIndex, + Type: field.Type, + QuerySelector: inheritedQuerySelector, + Tag: field.Tag, + Anonymous: field.Anonymous, + } + if querySelector := tags.ParseQuerySelector(field.Tag.Get(tags.QuerySelectorTag)); querySelector != "" { + descriptor.QuerySelector = querySelector } + fieldFS := parseFS(field.Tag, embedder.EmbedFS(), baseDir) if hasAny(field.Tag, tags.ViewTag, tags.SQLTag, tags.SQLSummaryTag, tags.LinkOnTag) { - parsed, err := tags.ParseViewTags(field.Tag, embedder.EmbedFS()) + descriptor.ViewTypeName, descriptor.ViewDest = parseShapeViewHints(field.Tag) + parsed, err := tags.ParseViewTags(field.Tag, fieldFS) if err != nil { return fmt.Errorf("shape scan: failed to parse view tags on %s: %w", path, err) } @@ -126,8 +161,8 @@ func (s *StructScanner) scanStruct( result.ViewFields = append(result.ViewFields, descriptor) } - if hasAny(field.Tag, tags.ParameterTag, tags.SQLTag, tags.PredicateTag, tags.CodecTag, tags.HandlerTag) { - parsed, err := tags.ParseStateTags(field.Tag, embedder.EmbedFS()) + if hasAny(field.Tag, tags.ParameterTag, tags.PredicateTag, tags.CodecTag, tags.HandlerTag) { + parsed, err := tags.ParseStateTags(field.Tag, fieldFS) if err != nil { return fmt.Errorf("shape scan: failed to parse state tags on %s: %w", path, err) } @@ -136,15 +171,36 @@ func (s *StructScanner) scanStruct( result.StateFields = append(result.StateFields, descriptor) } + if hasAny(field.Tag, componenttag.TagName) { + parsed, err := componenttag.Parse(field.Tag) + if err != nil { + return fmt.Errorf("shape scan: failed to parse component tags on %s: %w", path, err) + } + descriptor.HasComponentTag = true + descriptor.ComponentTag = parsed + fieldValue := fieldValueByIndex(rootValue, combinedIndex) + contract, err := resolveComponentContract(source, field.Type, fieldValue, parsed) + if err != nil { + return fmt.Errorf("shape scan: failed to resolve component contract on %s: %w", path, err) + } + if contract != nil { + descriptor.ComponentInputType = contract.InputType + descriptor.ComponentOutputType = contract.OutputType + descriptor.ComponentInputName = contract.InputName + descriptor.ComponentOutputName = contract.OutputName + if err := s.scanComponentContracts(source, path, fieldValue, contract, baseDir, result, visited); err != nil { + return err + } + } + result.ComponentFields = append(result.ComponentFields, descriptor) + } + result.Fields = append(result.Fields, descriptor) result.ByPath[path] = descriptor - nextType := field.Type - for nextType.Kind() == reflect.Ptr { - nextType = nextType.Elem() - } - if field.Anonymous && nextType.Kind() == reflect.Struct && !isStdlib(nextType.PkgPath()) { - if err := s.scanStruct(nextType, path, combinedIndex, embedder, result, visited); err != nil { + nextType := nestedStructType(field.Type) + if nextType != nil && shouldRecurseIntoField(field, descriptor, nextType) { + if err := s.scanStruct(source, nextType, rootValue, path, combinedIndex, descriptor.QuerySelector, embedder, baseDir, result, visited); err != nil { return err } } @@ -152,6 +208,95 @@ func (s *StructScanner) scanStruct( return nil } +func nestedStructType(rType reflect.Type) reflect.Type { + for rType != nil { + switch rType.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Array: + rType = rType.Elem() + default: + if rType.Kind() == reflect.Struct { + return rType + } + return nil + } + } + return nil +} + +func shouldRecurseIntoField(field reflect.StructField, descriptor *Field, nextType reflect.Type) bool { + if nextType == nil || nextType.Kind() != reflect.Struct { + return false + } + if field.Anonymous { + return !isStdlib(nextType.PkgPath()) + } + if descriptor != nil && descriptor.HasViewTag { + // Source-reconstructed and StructOf-based semantic view structs often have no package path. + // They still need recursive scanning so nested relation views are preserved. + return true + } + if descriptor != nil && strings.TrimSpace(descriptor.QuerySelector) != "" { + return true + } + return false +} + +func (s *StructScanner) scanComponentContracts( + source *shape.Source, + prefix string, + fieldValue reflect.Value, + contract *componentContract, + baseDir string, + result *Result, + visited map[reflect.Type]bool, +) error { + if contract == nil { + return nil + } + if contract.InputType != nil { + embedder := state.NewFSEmbedder(nil) + embedder.SetType(contract.InputType) + if err := s.scanStruct(source, contractInputRoot(contract.InputType), componentFieldValue(fieldValue, "Inout"), prefix+".Inout", nil, "", embedder, baseDir, result, visited); err != nil { + return err + } + } + if contract.OutputType != nil { + embedder := state.NewFSEmbedder(nil) + embedder.SetType(contract.OutputType) + if err := s.scanStruct(source, contractInputRoot(contract.OutputType), componentFieldValue(fieldValue, "Output"), prefix+".Output", nil, "", embedder, baseDir, result, visited); err != nil { + return err + } + } + return nil +} + +func contractInputRoot(rType reflect.Type) reflect.Type { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + return rType +} + +func fieldValueByIndex(rootValue reflect.Value, index []int) reflect.Value { + if !rootValue.IsValid() || len(index) == 0 { + return reflect.Value{} + } + current := rootValue + for _, idx := range index { + for current.IsValid() && current.Kind() == reflect.Ptr { + if current.IsNil() { + return reflect.Value{} + } + current = current.Elem() + } + if !current.IsValid() || current.Kind() != reflect.Struct || idx < 0 || idx >= current.NumField() { + return reflect.Value{} + } + current = current.Field(idx) + } + return current +} + func hasAny(tag reflect.StructTag, names ...string) bool { for _, name := range names { if _, ok := tag.Lookup(name); ok { @@ -161,9 +306,79 @@ func hasAny(tag reflect.StructTag, names ...string) bool { return false } +func parseFS(tag reflect.StructTag, existing *embed.FS, baseDir string) *embed.FS { + baseDir = strings.TrimSpace(baseDir) + if baseDir == "" { + return existing + } + uris := sqlURIs(tag) + if len(uris) == 0 { + return existing + } + holder := afsembed.NewHolder() + if existing != nil { + holder.AddFs(existing, ".") + } + added := 0 + for _, URI := range uris { + if URI == "" || filepath.IsAbs(URI) || strings.Contains(URI, "://") { + continue + } + absPath := filepath.Join(baseDir, filepath.FromSlash(URI)) + data, err := os.ReadFile(absPath) + if err != nil { + continue + } + holder.Add(filepath.ToSlash(URI), string(data)) + added++ + } + if added == 0 { + return existing + } + return holder.EmbedFs() +} + +func sqlURIs(tag reflect.StructTag) []string { + var result []string + appendURI := func(tagName string) { + value := strings.TrimSpace(tag.Get(tagName)) + if !strings.HasPrefix(value, "uri=") { + return + } + URI := strings.TrimSpace(value[4:]) + if URI != "" { + result = append(result, URI) + } + } + appendURI(tags.SQLTag) + appendURI(tags.SQLSummaryTag) + return result +} + func isStdlib(pkg string) bool { if pkg == "" { return true } return !strings.Contains(pkg, ".") } + +func parseShapeViewHints(tag reflect.StructTag) (string, string) { + raw, ok := tag.Lookup(tags.ViewTag) + if !ok { + return "", "" + } + _, values := taglytags.Values(raw).Name() + var typeName, dest string + _ = values.MatchPairs(func(key, value string) error { + switch strings.ToLower(strings.TrimSpace(key)) { + case "type": + typeName = strings.TrimSpace(value) + case "typename": + typeName = strings.TrimSpace(value) + case "dest": + dest = strings.TrimSpace(value) + } + return nil + }) + return typeName, dest +} diff --git a/repository/shape/scan/scanner_test.go b/repository/shape/scan/scanner_test.go index bf57d5cec..fc0bbda71 100644 --- a/repository/shape/scan/scanner_test.go +++ b/repository/shape/scan/scanner_test.go @@ -3,6 +3,8 @@ package scan import ( "context" "embed" + "os" + "path/filepath" "reflect" "testing" @@ -10,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/viant/datly/repository/shape" "github.com/viant/x" + "github.com/viant/xdatly" ) //go:embed testdata/*.sql @@ -28,8 +31,49 @@ type reportRow struct { type reportSource struct { embeddedFS - Rows []reportRow `view:"rows,table=REPORT,connector=dev" sql:"uri=testdata/report.sql"` - ID int `parameter:"id,kind=query,in=id"` + Rows []reportRow `view:"rows,table=REPORT,connector=dev,type=ReportRow,dest=rows.go" sql:"uri=testdata/report.sql"` + ID int `parameter:"id,kind=query,in=id"` + Route struct{} `component:",path=/v1/api/dev/report,method=GET,connector=dev"` +} + +type reportInput struct { + ID int +} + +type reportOutput struct { + Data []reportRow +} + +type typedComponentSource struct { + Route xdatly.Component[reportInput, reportOutput] `component:",path=/v1/api/dev/report,method=GET"` +} + +type reportEnabledSource struct { + Route xdatly.Component[reportInput, reportOutput] `component:",path=/v1/api/dev/report,method=GET,report=true,reportInput=NamedReportInput,reportDimensions=Dims,reportMeasures=Metrics,reportFilters=Predicates,reportOrderBy=Sort,reportLimit=Take,reportOffset=Skip"` +} + +type dynamicReportInput struct { + Name string +} + +type dynamicReportOutput struct { + Count int +} + +type namedReportInput struct { + Name string `parameter:"name,kind=query,in=name"` +} + +type namedReportOutput struct { + Data []reportRow `parameter:"data,kind=output,in=view"` +} + +type selectorHolderSource struct { + Route xdatly.Component[reportInput, reportOutput] `component:",path=/v1/api/dev/report,method=GET"` + ViewSelect struct { + Fields []string `parameter:"fields,kind=query,in=_fields"` + Page int `parameter:"page,kind=query,in=_page"` + } `querySelector:"rows"` } func TestStructScanner_Scan(t *testing.T) { @@ -49,6 +93,8 @@ func TestStructScanner_Scan(t *testing.T) { require.True(t, rows.HasViewTag) require.NotNil(t, rows.ViewTag) assert.Equal(t, "rows", rows.ViewTag.View.Name) + assert.Equal(t, "ReportRow", rows.ViewTypeName) + assert.Equal(t, "rows.go", rows.ViewDest) assert.Contains(t, rows.ViewTag.SQL.SQL, "SELECT ID, NAME FROM REPORT") idField := descriptors.ByPath["ID"] @@ -59,6 +105,15 @@ func TestStructScanner_Scan(t *testing.T) { assert.Equal(t, "id", idField.StateTag.Parameter.Name) assert.Equal(t, "query", idField.StateTag.Parameter.Kind) assert.Equal(t, "id", idField.StateTag.Parameter.In) + + route := descriptors.ByPath["Route"] + require.NotNil(t, route) + require.True(t, route.HasComponentTag) + require.NotNil(t, route.ComponentTag) + require.NotNil(t, route.ComponentTag.Component) + assert.Equal(t, "/v1/api/dev/report", route.ComponentTag.Component.Path) + assert.Equal(t, "GET", route.ComponentTag.Component.Method) + assert.Equal(t, "dev", route.ComponentTag.Component.Connector) } func TestStructScanner_Scan_InvalidSource(t *testing.T) { @@ -68,6 +123,133 @@ func TestStructScanner_Scan_InvalidSource(t *testing.T) { assert.Contains(t, err.Error(), "expected struct") } +func TestStructScanner_Scan_ComponentHolderTypes(t *testing.T) { + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &typedComponentSource{}}) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + require.Len(t, descriptors.ComponentFields, 1) + route := descriptors.ComponentFields[0] + require.NotNil(t, route) + assert.Equal(t, reflect.TypeOf(reportInput{}), route.ComponentInputType) + assert.Equal(t, reflect.TypeOf(reportOutput{}), route.ComponentOutputType) + assert.Empty(t, route.ComponentInputName) + assert.Empty(t, route.ComponentOutputName) +} + +func TestStructScanner_Scan_ComponentReportTags(t *testing.T) { + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportEnabledSource{}}) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + require.Len(t, descriptors.ComponentFields, 1) + route := descriptors.ComponentFields[0] + require.NotNil(t, route) + require.NotNil(t, route.ComponentTag) + require.NotNil(t, route.ComponentTag.Component) + assert.True(t, route.ComponentTag.Component.Report) + assert.Equal(t, "NamedReportInput", route.ComponentTag.Component.ReportInput) + assert.Equal(t, "Dims", route.ComponentTag.Component.ReportDimensions) + assert.Equal(t, "Metrics", route.ComponentTag.Component.ReportMeasures) + assert.Equal(t, "Predicates", route.ComponentTag.Component.ReportFilters) + assert.Equal(t, "Sort", route.ComponentTag.Component.ReportOrderBy) + assert.Equal(t, "Take", route.ComponentTag.Component.ReportLimit) + assert.Equal(t, "Skip", route.ComponentTag.Component.ReportOffset) +} + +func TestStructScanner_Scan_QuerySelectorHolder(t *testing.T) { + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &selectorHolderSource{}}) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + + fields := descriptors.ByPath["ViewSelect.Fields"] + require.NotNil(t, fields) + assert.Equal(t, "rows", fields.QuerySelector) + + page := descriptors.ByPath["ViewSelect.Page"] + require.NotNil(t, page) + assert.Equal(t, "rows", page.QuerySelector) +} + +func TestStructScanner_Scan_DynamicComponentHolderTypes(t *testing.T) { + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET"` + }{ + Route: xdatly.Component[any, any]{ + Inout: dynamicReportInput{}, + Output: dynamicReportOutput{}, + }, + }}) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + require.Len(t, descriptors.ComponentFields, 1) + route := descriptors.ComponentFields[0] + require.NotNil(t, route) + assert.Equal(t, reflect.TypeOf(dynamicReportInput{}), route.ComponentInputType) + assert.Equal(t, reflect.TypeOf(dynamicReportOutput{}), route.ComponentOutputType) +} + +func TestStructScanner_Scan_DynamicComponentHolderTypesWithExplicitNames(t *testing.T) { + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET,input=ReportInput,output=ReportOutput"` + }{}}) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + require.Len(t, descriptors.ComponentFields, 1) + route := descriptors.ComponentFields[0] + require.NotNil(t, route) + assert.Nil(t, route.ComponentInputType) + assert.Nil(t, route.ComponentOutputType) + assert.Equal(t, "ReportInput", route.ComponentInputName) + assert.Equal(t, "ReportOutput", route.ComponentOutputName) +} + +func TestStructScanner_Scan_DynamicComponentHolderTypesWithExplicitNamesFromRegistry(t *testing.T) { + scanner := New() + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(namedReportInput{}), x.WithPkgPath("github.com/viant/datly/repository/shape/scan"), x.WithName("ReportInput"))) + registry.Register(x.NewType(reflect.TypeOf(namedReportOutput{}), x.WithPkgPath("github.com/viant/datly/repository/shape/scan"), x.WithName("ReportOutput"))) + result, err := scanner.Scan(context.Background(), &shape.Source{ + Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET,input=ReportInput,output=ReportOutput"` + }{}, + TypeRegistry: registry, + }) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + require.Len(t, descriptors.ComponentFields, 1) + route := descriptors.ComponentFields[0] + require.NotNil(t, route) + assert.Equal(t, reflect.TypeOf(namedReportInput{}), route.ComponentInputType) + assert.Equal(t, reflect.TypeOf(namedReportOutput{}), route.ComponentOutputType) + assert.Equal(t, "ReportInput", route.ComponentInputName) + assert.Equal(t, "ReportOutput", route.ComponentOutputName) +} + +func TestStructScanner_Scan_DynamicComponentHolderTypesRequireContract(t *testing.T) { + scanner := New() + _, err := scanner.Scan(context.Background(), &shape.Source{Struct: &struct { + Route xdatly.Component[any, any] `component:",path=/v1/api/dev/report,method=GET"` + }{}}) + require.Error(t, err) + assert.Contains(t, err.Error(), "dynamic component holder requires explicit input/output tag names or initialized Inout/Output values") +} + func TestStructScanner_Scan_WithRegistryType(t *testing.T) { scanner := New() registry := x.NewRegistry() @@ -81,3 +263,65 @@ func TestStructScanner_Scan_WithRegistryType(t *testing.T) { require.True(t, ok) assert.Equal(t, reflect.TypeOf(reportSource{}), descriptors.RootType) } + +func TestStructScanner_Scan_UsesSourceBaseDirForRelativeSQL(t *testing.T) { + scanner := New() + baseDir := t.TempDir() + sqlPath := filepath.Join(baseDir, "routes", "report.sql") + require.NoError(t, os.MkdirAll(filepath.Dir(sqlPath), 0o755)) + require.NoError(t, os.WriteFile(sqlPath, []byte("SELECT ID FROM REPORT"), 0o644)) + + type reportView struct { + Data []reportRow `view:"rows" sql:"uri=routes/report.sql"` + } + + result, err := scanner.Scan(context.Background(), &shape.Source{ + Type: reflect.TypeOf(reportView{}), + Path: filepath.Join(baseDir, "router.go"), + }) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + viewField := descriptors.ByPath["Data"] + require.NotNil(t, viewField) + require.NotNil(t, viewField.ViewTag) + assert.Equal(t, "SELECT ID FROM REPORT", string(viewField.ViewTag.SQL.SQL)) + assert.Equal(t, "routes/report.sql", string(viewField.ViewTag.SQL.URI)) +} + +func TestStructScanner_Scan_RecursesIntoViewTaggedStructFields(t *testing.T) { + type vendorProduct struct { + ID int `sqlx:"ID"` + VendorID int `sqlx:"VENDOR_ID"` + } + type vendorRow struct { + ID int `sqlx:"ID"` + Products []*vendorProduct `view:",table=PRODUCT" on:"Id:ID=VendorId:VENDOR_ID" sql:"uri=testdata/report.sql"` + } + type nestedViewOutput struct { + Data []*vendorRow `parameter:",kind=output,in=view" view:"vendor" sql:"uri=testdata/report.sql" anonymous:"true"` + } + type nestedViewRouteSource struct { + Route xdatly.Component[reportInput, nestedViewOutput] `component:",path=/v1/api/dev/vendors,method=GET"` + } + + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &nestedViewRouteSource{}}) + require.NoError(t, err) + + descriptors, ok := DescriptorsFrom(result) + require.True(t, ok) + + rootView := descriptors.ByPath["Route.Output.Data"] + require.NotNil(t, rootView) + require.True(t, rootView.HasViewTag) + + childView := descriptors.ByPath["Route.Output.Data.Products"] + require.NotNil(t, childView) + require.True(t, childView.HasViewTag) + require.NotNil(t, childView.ViewTag) + assert.Equal(t, "PRODUCT", childView.ViewTag.View.Table) + assert.Len(t, descriptors.ViewFields, 2) + assert.Nil(t, descriptors.ByPath["Route.Output.Data.Products"].StateTag) +} diff --git a/repository/shape/shape.go b/repository/shape/shape.go index 5f7f766d8..ef6f38042 100644 --- a/repository/shape/shape.go +++ b/repository/shape/shape.go @@ -21,6 +21,7 @@ type ( // Loader materializes runtime artifacts from normalized plan. Loader interface { LoadViews(ctx context.Context, plan *PlanResult, opts ...LoadOption) (*ViewArtifacts, error) + LoadResource(ctx context.Context, plan *PlanResult, opts ...LoadOption) (*ResourceArtifacts, error) LoadComponent(ctx context.Context, plan *PlanResult, opts ...LoadOption) (*ComponentArtifact, error) } @@ -35,9 +36,11 @@ type ( RegisterComponent(ctx context.Context, artifacts *ComponentArtifact) error } - ScanOptions struct{} - PlanOptions struct{} - LoadOptions struct{} + ScanOptions struct{} + PlanOptions struct{} + LoadOptions struct { + UseTypeContextPackages bool + } CompileOptions struct { Strict bool Profile CompileProfile @@ -50,6 +53,7 @@ type ( TypePackageName string TypePackagePath string InferTypeContext *bool + UseLinkedTypes *bool } ScanOption func(*ScanOptions) @@ -58,6 +62,15 @@ type ( CompileOption func(*CompileOptions) ) +func WithLoadTypeContextPackages(enabled bool) LoadOption { + return func(o *LoadOptions) { + if o == nil { + return + } + o.UseTypeContextPackages = enabled + } +} + const ( CompileMixedModeExecWins CompileMixedMode = "exec_wins" CompileMixedModeReadWins CompileMixedMode = "read_wins" @@ -94,11 +107,21 @@ func LoadComponent(ctx context.Context, src any, opts ...Option) (*ComponentArti return New(opts...).LoadComponent(ctx, src) } +// LoadResource is a package-level helper for struct source resource loading. +func LoadResource(ctx context.Context, src any, opts ...Option) (*ResourceArtifacts, error) { + return New(opts...).LoadResource(ctx, src) +} + // LoadDQLViews is a package-level helper for DQL source view loading. func LoadDQLViews(ctx context.Context, dql string, opts ...Option) (*ViewArtifacts, error) { return New(opts...).LoadDQLViews(ctx, dql) } +// LoadDQLResource is a package-level helper for DQL source resource loading. +func LoadDQLResource(ctx context.Context, dql string, opts ...Option) (*ResourceArtifacts, error) { + return New(opts...).LoadDQLResource(ctx, dql) +} + // LoadDQLComponent is a package-level helper for DQL source component loading. func LoadDQLComponent(ctx context.Context, dql string, opts ...Option) (*ComponentArtifact, error) { return New(opts...).LoadDQLComponent(ctx, dql) @@ -120,6 +143,22 @@ func (e *Engine) LoadViews(ctx context.Context, src any) (*ViewArtifacts, error) return e.options.Loader.LoadViews(ctx, plan) } +// LoadResource executes scan -> plan -> load for struct source. +func (e *Engine) LoadResource(ctx context.Context, src any) (*ResourceArtifacts, error) { + source, err := e.structSource(src) + if err != nil { + return nil, err + } + plan, err := e.scanAndPlan(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadResource(ctx, plan) +} + // LoadComponent executes scan -> plan -> load for struct source. func (e *Engine) LoadComponent(ctx context.Context, src any) (*ComponentArtifact, error) { source, err := e.structSource(src) @@ -152,6 +191,22 @@ func (e *Engine) LoadDQLViews(ctx context.Context, dql string) (*ViewArtifacts, return e.options.Loader.LoadViews(ctx, plan) } +// LoadDQLResource executes compile -> load for DQL source. +func (e *Engine) LoadDQLResource(ctx context.Context, dql string) (*ResourceArtifacts, error) { + source, err := e.dqlSource(dql) + if err != nil { + return nil, err + } + plan, err := e.compile(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadResource(ctx, plan) +} + // LoadDQLComponent executes compile -> load for DQL source. func (e *Engine) LoadDQLComponent(ctx context.Context, dql string) (*ComponentArtifact, error) { source, err := e.dqlSource(dql) diff --git a/repository/shape/validate/relation.go b/repository/shape/validate/relation.go index 31aee9357..c0774330c 100644 --- a/repository/shape/validate/relation.go +++ b/repository/shape/validate/relation.go @@ -36,6 +36,11 @@ func ValidateRelations(resource *view.Resource, targets ...*view.View) error { } } refIndex := view.Columns(ref.Columns).Index(ref.CaseFormat) + // Shape load runs before DB-backed column discovery in transcribe. + // When either side has no columns yet, defer strict relation checks. + if len(parent.Columns) == 0 || len(ref.Columns) == 0 { + continue + } pairCount := len(rel.On) if len(rel.Of.On) > pairCount { pairCount = len(rel.Of.On) diff --git a/repository/shape/velty/ast/assign.go b/repository/shape/velty/ast/assign.go new file mode 100644 index 000000000..f058ff248 --- /dev/null +++ b/repository/shape/velty/ast/assign.go @@ -0,0 +1,110 @@ +package ast + +import "fmt" + +func (s *Assign) Generate(builder *Builder) (err error) { + if builder.AssignNotifier != nil { + newExpr, err := builder.AssignNotifier(s) + if err != nil { + return err + } + + if newExpr != nil && newExpr != s { + return newExpr.Generate(builder) + } + } + + switch builder.Lang { + case LangVelty: + if err = builder.WriteIndentedString("\n#set("); err != nil { + return err + } + if err = s.Holder.Generate(builder); err != nil { + return err + } + if err = builder.WriteString(" = "); err != nil { + return err + } + if err = s.Expression.Generate(builder); err != nil { + return err + } + if err = builder.WriteString(")"); err != nil { + return err + } + return nil + + case LangGO: + + callExpr, ok := s.Expression.(*CallExpr) + if ok && callExpr.Name == "IndexBy" && builder.Options.Lang == LangGO { + if builder.IndexByCode == nil { + builder.IndexByCode = builder.NewBuilder() + } + indexBuilder := builder.IndexByCode + asIdent, _ := s.Holder.(*Ident) + if holder := asIdent.Holder; holder != "" { + indexBuilder.WriteString(holder + ".") + } + indexBuilder.WriteString(asIdent.Name) + indexBuilder.WriteString(" = ") + if err = s.Expression.Generate(indexBuilder); err != nil { + return err + } + indexBuilder.WriteString("\n") + return nil + } + + if err = builder.WriteIndentedString("\n"); err != nil { + return err + } + asIdent, ok := s.Holder.(*Ident) + wasDeclared := true + if ok { + wasDeclared = builder.State.IsDeclared(asIdent.Name) + } + + if err = s.Holder.Generate(builder); err != nil { + return err + } + + for _, holder := range s.ExtraHolders { + if err = builder.WriteString(", "); err != nil { + return err + } + + if err = holder.Generate(builder); err != nil { + return err + } + } + + if err = s.appendGoAssignToken(builder, wasDeclared); err != nil { + return err + } + + if err = s.Expression.Generate(builder); err != nil { + return err + } + if !wasDeclared { + builder.State.DeclareVariable(asIdent.Name) + } + return nil + } + return fmt.Errorf("unsupported option %T %v\n", s, builder.Lang) + +} + +func (s *Assign) appendGoAssignToken(builder *Builder, isDeclared bool) error { + + if isDeclared { + return builder.WriteString(" = ") + } + + return builder.WriteString(" := ") +} + +func NewAssign(holder Expression, expr Expression) *Assign { + return &Assign{ + Holder: holder, + Expression: expr, + } +} diff --git a/repository/shape/velty/ast/ast.go b/repository/shape/velty/ast/ast.go new file mode 100644 index 000000000..98290ecc7 --- /dev/null +++ b/repository/shape/velty/ast/ast.go @@ -0,0 +1,163 @@ +package ast + +import ( + "fmt" + "github.com/viant/tagly/format/text" +) + +type ( + Node interface { + Generate(builder *Builder) error + } + Statement interface { + Node + } + + Expression interface { + Node + } //can be BinaryExpr or CallExpr or QuerySelector Expr + + Block []Statement + Ident struct { + Holder string + Name string + } + + Foreach struct { + Value *Ident + Set *Ident + Body Block + } + + Assign struct { + Holder Expression + ExtraHolders []Expression + Expression Expression + } + + CallExpr struct { + Terminator bool + Receiver Expression + Name string + Args []Expression + } + + MapExpr struct { + Map Expression + Key Expression + } + + StatementExpression struct { + Expression + } + + TerminatorExpression struct { + X Expression + } + + SelectorExpr struct { + Ident + X *SelectorExpr + } + + BinaryExpr struct { + X Expression + Op string + Y Expression + } + + LiteralExpr struct { + Literal string + } + + ReturnExpr struct { + X Expression + } +) + +func NewReturnExpr(expr Expression) *ReturnExpr { + return &ReturnExpr{X: expr} +} +func (r *ReturnExpr) Generate(builder *Builder) error { + switch builder.Lang { + case LangGO: + if err := builder.WriteIndentedString("\nreturn "); err != nil { + return err + } + + return r.X.Generate(builder) + } + + return fmt.Errorf("unsupported %T with lang %v", r, builder.Lang) +} + +func (m *MapExpr) Generate(builder *Builder) error { + if err := m.Map.Generate(builder); err != nil { + return err + } + + if err := builder.WriteString("["); err != nil { + return err + } + + if err := m.Key.Generate(builder); err != nil { + return err + } + + if err := builder.WriteString("]"); err != nil { + return err + } + + return nil +} + +func (b *Block) Append(statement Statement) { + *b = append(*b, statement) +} + +func (b *Block) AppendEmptyLine() { + b.Append(NewStatementExpression(NewLiteral(""))) +} + +func (b Block) Generate(builder *Builder) error { + if builder.WithoutBusinessLogic { + return nil + } + for _, stmt := range b { + if err := stmt.Generate(builder); err != nil { + return err + } + } + return nil +} + +func (e Ident) Generate(builder *Builder) (err error) { + identName := e.Name + if builder.WithLowerCaseIdent && e.Holder == "" { + upperCamel := text.CaseFormatUpperCamel + identName = upperCamel.Format(identName, text.CaseFormatLowerCamel) + } + + if e.Holder != "" { + identName = e.Holder + "." + identName + } + builder.State.DeclareVariable(identName) + if builder.Lang == LangVelty { + return builder.WriteString("$" + identName) + } + return builder.WriteString(identName) +} + +func (b TerminatorExpression) Generate(builder *Builder) error { + if err := b.X.Generate(builder); err != nil { + return err + } + if builder.Lang == LangVelty { + return builder.WriteByte(';') + } + return nil +} + +func NewTerminatorExpression(x Expression) *TerminatorExpression { + return &TerminatorExpression{X: x} +} diff --git a/repository/shape/velty/ast/ast_test.go b/repository/shape/velty/ast/ast_test.go new file mode 100644 index 000000000..6d226b931 --- /dev/null +++ b/repository/shape/velty/ast/ast_test.go @@ -0,0 +1,124 @@ +package ast + +import ( + "github.com/stretchr/testify/assert" + "strings" + "testing" +) + +func TestBlock_Stringify(t *testing.T) { + + var testCases = []struct { + description string + block Block + options Options + expect string + }{ + { + description: "assign", + options: Options{Lang: LangVelty}, + block: Block{ + &Assign{Holder: &Ident{Name: "inited"}, Expression: &CallExpr{Receiver: Ident{Name: "Campaign"}, Name: "init", Args: []Expression{ + Ident{Name: "CurCampaign"}, + }}}, + }, + expect: `#set($inited = $Campaign.init($CurCampaign))`, + }, + { + description: "for each ", + options: Options{Lang: LangVelty}, + block: Block{ + &Foreach{Set: &Ident{Name: "Sets"}, + Value: &Ident{Name: "Item"}, + Body: Block{ + &Assign{Holder: &Ident{Name: "tested"}, Expression: &CallExpr{Receiver: Ident{Name: "Campaign"}, Name: "Test", Args: []Expression{ + Ident{Name: "Item"}, + }}}}}}, + expect: `#foreach($Item in $Sets) + #set($tested = $Campaign.Test($Item)) +#end`, + }, + { + description: "if condition", + options: Options{Lang: LangVelty}, + block: Block{ + &Condition{ + If: &BinaryExpr{X: &Ident{Name: "Campaign.Id"}, Op: ">", Y: &LiteralExpr{Literal: "1"}}, + IFBlock: Block{ + &Assign{Holder: &Ident{Name: "inited"}, Expression: &CallExpr{Receiver: Ident{Name: "Campaign"}, Name: "init", Args: []Expression{ + Ident{Name: "CurCampaign"}, + }}}, + }, + ElseIfBlocks: []*ConditionalBlock{{ + If: &BinaryExpr{X: &Ident{Name: "Campaign.Name"}, Op: "==", Y: &LiteralExpr{Literal: `"Foo"`}}, + Block: Block{ + &Assign{Holder: &Ident{Name: "fooed"}, Expression: &CallExpr{Receiver: Ident{Name: "Campaign"}, Name: "Foo", Args: []Expression{ + Ident{Name: "CurCampaign"}, + }}}, + }, + }, + }, + }, + }, + expect: `#if($Campaign.Id > 1) + #set($inited = $Campaign.init($CurCampaign)) +#elseif($Campaign.Name == "Foo") + #set($fooed = $Campaign.Foo($CurCampaign)) +#end`, + }, + { + description: "assign condition | go", + options: Options{Lang: LangGO}, + block: Block{ + &Assign{Holder: &Ident{Name: "foo"}, Expression: &LiteralExpr{Literal: "10"}}, + }, + expect: `foo := 10`, + }, + { + description: "if stmt | go", + options: Options{Lang: LangGO}, + block: Block{ + &Condition{ + If: &BinaryExpr{X: &LiteralExpr{"0"}, Y: &Ident{Name: "foo"}, Op: ">"}, + IFBlock: Block{ + &Assign{Holder: &Ident{Name: "foo"}, Expression: &BinaryExpr{X: &Ident{Name: "foo"}, Op: "*", Y: &LiteralExpr{Literal: "-1"}}}, + }, + }, + }, + expect: `if 0 > foo { + foo = foo * -1 +}`, + }, + { + description: "foreach", + options: Options{Lang: LangGO}, + block: Block{ + &Foreach{ + Value: &Ident{Name: "foo"}, + Set: &Ident{Name: "foos"}, + Body: Block{ + &CallExpr{ + Receiver: &Ident{Name: "fmt"}, + Name: "Printf", + Args: []Expression{&Ident{Name: "foo"}}, + }, + }, + }, + }, + expect: `for _, foo := range foos { + fmt.Printf(foo) +}`, + }, + } + + //for _, testCase := range testCases[len(testCases)-1:] { + for _, testCase := range testCases { + builder := NewBuilder(testCase.options) + err := testCase.block.Generate(builder) + if !assert.Nil(t, err, testCase.description) { + continue + } + actual := builder.String() + assert.EqualValues(t, testCase.expect, strings.TrimSpace(actual)) + } +} diff --git a/repository/shape/velty/ast/binary.go b/repository/shape/velty/ast/binary.go new file mode 100644 index 000000000..cb4231e96 --- /dev/null +++ b/repository/shape/velty/ast/binary.go @@ -0,0 +1,22 @@ +package ast + +func NewBinary(x Expression, op string, y Expression) *BinaryExpr { + return &BinaryExpr{X: x, Op: op, Y: y} +} + +func (s *BinaryExpr) Generate(builder *Builder) (err error) { + if err := s.X.Generate(builder); err != nil { + return err + } + if err = builder.WriteString(" "); err != nil { + return err + } + + if err = builder.WriteString(s.Op); err != nil { + return err + } + if err = builder.WriteString(" "); err != nil { + return err + } + return s.Y.Generate(builder) +} diff --git a/repository/shape/velty/ast/builder.go b/repository/shape/velty/ast/builder.go new file mode 100644 index 000000000..53fd18d25 --- /dev/null +++ b/repository/shape/velty/ast/builder.go @@ -0,0 +1,52 @@ +package ast + +import "strings" + +const ( + LangVelty = "velty" + LangGO = "go" +) + +type ( + Builder struct { + *strings.Builder + Options + Indent string + State *Scope + declarations map[string]string + IndexByCode *Builder + } +) + +func (b *Builder) NewBuilder() *Builder { + r := *b + r.Builder = &strings.Builder{} + return &r +} + +func (b *Builder) WriteIndentedString(s string) error { + fragment := strings.ReplaceAll(s, "\n", "\n"+b.Indent) + _, err := b.Builder.WriteString(fragment) + return err +} + +func (b *Builder) IncIndent(indent string) *Builder { + newBuilder := *b + newBuilder.Indent += indent + newBuilder.State = newBuilder.State.NextScope() + return &newBuilder +} + +func (b *Builder) WriteString(s string) error { + _, err := b.Builder.WriteString(s) + return err +} + +func NewBuilder(option Options, declaredVariables ...string) *Builder { + return &Builder{ + Builder: &strings.Builder{}, + Options: option, + declarations: map[string]string{}, + State: NewScope(declaredVariables...), + } +} diff --git a/repository/shape/velty/ast/condition.go b/repository/shape/velty/ast/condition.go new file mode 100644 index 000000000..60b88b3a5 --- /dev/null +++ b/repository/shape/velty/ast/condition.go @@ -0,0 +1,142 @@ +package ast + +import ( + "fmt" +) + +type ( + Condition struct { + If Expression + IFBlock Block + ElseIfBlocks []*ConditionalBlock + ElseBlock Block + } + + ConditionalBlock struct { + If Expression + Block Block + } +) + +func (s *Condition) Generate(builder *Builder) (err error) { + if builder.OnIfNotifier != nil { + if expr, err := builder.OnIfNotifier(s); err != nil { + return err + } else if expr != nil && expr != s { + return expr.Generate(builder) + } + } + + switch builder.Lang { + case LangVelty: + if err = builder.WriteIndentedString("\n#if("); err != nil { + return err + } + if err = s.If.Generate(builder); err != nil { + return err + } + if err = builder.WriteString(")"); err != nil { + return err + } + bodyBuilder := builder.IncIndent(" ") + if err = s.IFBlock.Generate(bodyBuilder); err != nil { + return err + } + for _, item := range s.ElseIfBlocks { + if err = builder.WriteIndentedString("\n#elseif("); err != nil { + return err + } + if err = item.If.Generate(builder); err != nil { + return err + } + if err = builder.WriteString(")"); err != nil { + return err + } + if err = item.Block.Generate(bodyBuilder); err != nil { + return err + } + } + if s.ElseBlock != nil { + if err = builder.WriteIndentedString("\n#else"); err != nil { + return err + } + if err = s.ElseBlock.Generate(bodyBuilder); err != nil { + return err + } + } + if err = builder.WriteIndentedString("\n#end"); err != nil { + return err + } + return nil + + case LangGO: + if err = builder.WriteIndentedString("\nif "); err != nil { + return err + } + + if err = s.If.Generate(builder); err != nil { + return err + } + + if err = builder.WriteString(" {"); err != nil { + return err + } + + bodyBlockBuilder := builder.IncIndent(" ") + if err = s.IFBlock.Generate(bodyBlockBuilder); err != nil { + return err + } + + if err = builder.WriteIndentedString("\n}"); err != nil { + return err + } + + for _, block := range s.ElseIfBlocks { + if err = builder.WriteString(" else if "); err != nil { + return err + } + + if err = block.If.Generate(builder); err != nil { + return err + } + + if err = builder.WriteString(" { "); err != nil { + return err + } + + if err = block.Block.Generate(bodyBlockBuilder); err != nil { + return err + } + + if err = builder.WriteIndentedString("\n} "); err != nil { + return err + } + } + + if len(s.ElseBlock) > 0 { + if err = builder.WriteString(" else "); err != nil { + return err + } + + if err = builder.WriteString(" { "); err != nil { + return err + } + + if err = s.ElseBlock.Generate(bodyBlockBuilder); err != nil { + return err + } + + if err = builder.WriteIndentedString("\n} "); err != nil { + return err + } + } + + return nil + } + + return fmt.Errorf("unsupported option %T %v\n", s, builder.Lang) +} + +func NewCondition(ifExpr Expression, ifBlock, elseBlock Block) *Condition { + return &Condition{If: ifExpr, IFBlock: ifBlock, ElseBlock: elseBlock} +} diff --git a/repository/shape/velty/ast/dml.go b/repository/shape/velty/ast/dml.go new file mode 100644 index 000000000..e6c333d15 --- /dev/null +++ b/repository/shape/velty/ast/dml.go @@ -0,0 +1,97 @@ +package ast + +import ( + "fmt" + "strings" +) + +type Insert struct { + Table string + Columns []string + Fields []string +} + +func (s *Insert) Generate(builder *Builder) (err error) { + switch builder.Lang { + case LangVelty: + builder.WriteString("INSERT INTO ") + builder.WriteString(s.Table) + builder.WriteString("(") + builder.WriteString(strings.Join(s.Columns, ",")) + builder.WriteString(") Fields(") + builder.WriteString(strings.Join(s.Fields, ",")) + builder.WriteString(");") + case LangGO: + return fmt.Errorf("DML not yet supported for golang") + } + return nil +} + +type Update struct { + Table string + Columns []string + Fields []string + PkColumns []string + PkFields []string +} + +func (s *Update) Generate(builder *Builder) (err error) { + switch builder.Lang { + case LangVelty: + if err = builder.WriteString("UPDATE "); err != nil { + return err + } + if err = builder.WriteString(s.Table); err != nil { + return err + } + if err = builder.WriteString("SET "); err != nil { + return err + } + for i, column := range s.PkColumns { + if i > 0 { + if err = builder.WriteString(","); err != nil { + return err + } + } + if err = builder.WriteString(column); err != nil { + return err + } + if err = builder.WriteString(" = "); err != nil { + return err + } + if err = builder.WriteString(s.PkFields[i]); err != nil { + return err + } + } + for i, column := range s.Columns { + if err = builder.WriteString("\t#if("); err == nil { + if err = builder.WriteString(getHasField(s.Fields[i])); err == nil { + if err = builder.WriteString(")"); err == nil { + if err = builder.WriteString(","); err == nil { + if err = builder.WriteString(column); err == nil { + if err = builder.WriteString(" = "); err == nil { + if err = builder.WriteString(s.Fields[i]); err == nil { + err = builder.WriteString("\t#end") + } + } + } + } + } + } + } + } + return err + case LangGO: + return fmt.Errorf("DML not yet supported for golang") + } + return nil +} + +func getHasField(field string) string { + if index := strings.LastIndex(field, "."); index != -1 { + leaf := field[index+1:] + field = field[:index] + return field + "." + "Has." + leaf + } + return field +} diff --git a/repository/shape/velty/ast/errcheck.go b/repository/shape/velty/ast/errcheck.go new file mode 100644 index 000000000..d225faa2d --- /dev/null +++ b/repository/shape/velty/ast/errcheck.go @@ -0,0 +1,29 @@ +package ast + +type ErrorCheck struct { + X Expression +} + +func (e *ErrorCheck) Generate(builder *Builder) error { + switch builder.Options.Lang { + case LangGO: + if err := builder.WriteString("if err ="); err != nil { + return err + } + if err := e.X.Generate(builder); err != nil { + return err + } + return builder.WriteString(";err != nil {\nreturn err\n}") + case LangVelty: + if err := builder.WriteString("\n"); err != nil { + return err + } + return e.X.Generate(builder) + } + + return unsupportedOptionUse(builder, e) +} + +func NewErrorCheck(expr Expression) *ErrorCheck { + return &ErrorCheck{X: expr} +} diff --git a/repository/shape/velty/ast/expression.go b/repository/shape/velty/ast/expression.go new file mode 100644 index 000000000..6c1f8889c --- /dev/null +++ b/repository/shape/velty/ast/expression.go @@ -0,0 +1,95 @@ +package ast + +import "fmt" + +func NewCallExpr(holder Expression, name string, args ...Expression) *CallExpr { + return &CallExpr{ + Receiver: holder, + Name: name, + Args: args, + } +} + +func (s *StatementExpression) Generate(builder *Builder) (err error) { + if err = builder.WriteIndentedString("\n"); err != nil { + return err + } + return s.Expression.Generate(builder) +} + +// NewStatementExpression return new statement expr +func NewStatementExpression(expr Expression) *StatementExpression { + return &StatementExpression{Expression: expr} +} +func (e *CallExpr) Generate(builder *Builder) (err error) { + expr, err := e.actualExpr(builder) + if err != nil { + return err + } + if expr != e { + return expr.Generate(builder) + } + + if e.Receiver != nil { + + if err = e.Receiver.Generate(builder); err != nil { + return err + } + + if err = builder.WriteString("."); err != nil { + return err + } + } + if err = builder.WriteString(e.Name); err != nil { + return err + } + + if err = builder.WriteString("("); err != nil { + return err + } + for i, arg := range e.Args { + if i > 0 { + if err = builder.WriteString(", "); err != nil { + return err + } + } + if err = arg.Generate(builder); err != nil { + return err + } + } + if err = builder.WriteString(")"); err != nil { + return err + } + + return nil +} + +func (e *CallExpr) actualExpr(builder *Builder) (Expression, error) { + if builder.CallNotifier == nil { + return e, nil + } + + notifier, err := builder.CallNotifier(e) + if err != nil || notifier != nil { + return notifier, err + } + + return e, nil +} + +func (s *SelectorExpr) Generate(builder *Builder) error { + return unsupportedOptionUse(builder, s) +} + +func unsupportedOptionUse(builder *Builder, s Expression) error { + return fmt.Errorf("unsupported option %T %v\n", s, builder.Lang) +} + +func NewIdent(name string) *Ident { + return &Ident{Name: name} +} + +func NewHolderIndent(holder, name string) *Ident { + ret := &Ident{Name: name, Holder: holder} + return ret +} diff --git a/repository/shape/velty/ast/foreach.go b/repository/shape/velty/ast/foreach.go new file mode 100644 index 000000000..2e096cde7 --- /dev/null +++ b/repository/shape/velty/ast/foreach.go @@ -0,0 +1,86 @@ +package ast + +import "fmt" + +func (s *Foreach) Generate(builder *Builder) (err error) { + if builder.SliceItemNotifier != nil { + if err = builder.SliceItemNotifier(s.Value, s.Set); err != nil { + return err + } + } + + switch builder.Lang { + case LangVelty: + if err = builder.WriteIndentedString("\n#foreach("); err != nil { + return err + } + if err = s.Value.Generate(builder); err != nil { + return err + } + if err = builder.WriteString(" in "); err != nil { + return err + } + if err = s.Set.Generate(builder); err != nil { + return err + } + if err = builder.WriteString(")"); err != nil { + return err + } + + bodyBuilder := builder.IncIndent(" ") + if err = s.Body.Generate(bodyBuilder); err != nil { + return err + } + if err = builder.WriteIndentedString("\n#end"); err != nil { + return err + } + return nil + + case LangGO: + + if err = builder.WriteIndentedString("\nfor _, "); err != nil { + return err + } + + if err = s.Value.Generate(builder); err != nil { + return err + } + + if err = builder.WriteString(" := range "); err != nil { + return err + } + + if err = s.Set.Generate(builder); err != nil { + return err + } + + if err = builder.WriteString(" { "); err != nil { + return err + } + + bodyBuilder := builder.IncIndent(" ") + if err = bodyBuilder.WriteIndentedString("\n"); err != nil { + return err + } + + if err = s.Body.Generate(bodyBuilder); err != nil { + return err + } + + if err = builder.WriteString("\n}"); err != nil { + return err + } + + return nil + } + + return fmt.Errorf("unsupported option %T %v\n", s, builder.Lang) +} + +func NewForEach(value, set *Ident, body Block) *Foreach { + return &Foreach{ + Value: value, + Set: set, + Body: body, + } +} diff --git a/repository/shape/velty/ast/func.go b/repository/shape/velty/ast/func.go new file mode 100644 index 000000000..28307ee58 --- /dev/null +++ b/repository/shape/velty/ast/func.go @@ -0,0 +1,162 @@ +package ast + +type ( + Function struct { + Receiver *Receiver + Name string + ArgsIn []*FuncArg + ArgsOut []string + Body Block + Return *ReturnExpr + } + + Receiver struct { + Name string + Ident *Ident + } + + FuncArg struct { + Name string + Ident *Ident + } +) + +func (a *FuncArg) Generate(builder *Builder) error { + switch builder.Lang { + case LangGO: + if err := builder.WriteString(a.Name); err != nil { + return err + } + + if err := builder.WriteString(" "); err != nil { + return err + } + + return a.Ident.Generate(builder) + } + + return unsupportedOptionUse(builder, a) +} + +func (r *Receiver) Generate(builder *Builder) error { + switch builder.Lang { + case LangGO: + if err := builder.WriteString(r.Name); err != nil { + return err + } + + if err := builder.WriteString(" "); err != nil { + return err + } + + return r.Ident.Generate(builder) + } + + return unsupportedOptionUse(builder, r) +} + +func (f *Function) Generate(builder *Builder) error { + switch builder.Lang { + case LangGO: + if err := builder.WriteIndentedString("\nfunc "); err != nil { + return err + } + + if f.Receiver != nil { + if err := builder.WriteString("( "); err != nil { + return err + } + + if err := f.Receiver.Generate(builder); err != nil { + return err + } + + if err := builder.WriteString(" ) "); err != nil { + return err + } + } + + if err := builder.WriteString(f.Name); err != nil { + return err + } + + if err := builder.WriteString("("); err != nil { + return err + } + + for i, arg := range f.ArgsIn { + if i != 0 { + if err := builder.WriteString(", "); err != nil { + return err + } + } + + if err := arg.Generate(builder); err != nil { + return err + } + } + + if err := builder.WriteString(") "); err != nil { + return err + } + + switch len(f.ArgsOut) { + case 0: + //Exec nothing + case 1: + if err := builder.WriteString(f.ArgsOut[0]); err != nil { + return err + } + + default: + for i, argType := range f.ArgsOut { + if err := builder.WriteString("("); err != nil { + return err + } + + if i != 0 { + if err := builder.WriteString(", "); err != nil { + return err + } + + } + + if err := builder.WriteString(argType); err != nil { + return err + } + + if err := builder.WriteString(")"); err != nil { + return err + } + } + } + + if err := builder.WriteString(" {"); err != nil { + return err + } + + blockBuilder := builder.IncIndent(" ") + if err := blockBuilder.WriteIndentedString("\n"); err != nil { + return err + } + + if err := f.Body.Generate(blockBuilder); err != nil { + return err + } + + if f.Return != nil { + if err := f.Return.Generate(builder); err != nil { + return err + } + } + + if err := builder.WriteIndentedString("\n}"); err != nil { + return err + } + + return nil + + default: + return unsupportedOptionUse(builder, f) + } +} diff --git a/repository/shape/velty/ast/literal.go b/repository/shape/velty/ast/literal.go new file mode 100644 index 000000000..04b2a7fe5 --- /dev/null +++ b/repository/shape/velty/ast/literal.go @@ -0,0 +1,21 @@ +package ast + +import ( + "strconv" + "strings" +) + +func (s *LiteralExpr) Generate(builder *Builder) error { + return builder.WriteString(s.Literal) +} + +func NewQuotedLiteral(text string) *LiteralExpr { + if !strings.HasPrefix(text, "\"") { + text = strconv.Quote(text) + } + return &LiteralExpr{text} +} + +func NewLiteral(text string) *LiteralExpr { + return &LiteralExpr{text} +} diff --git a/repository/shape/velty/ast/options.go b/repository/shape/velty/ast/options.go new file mode 100644 index 000000000..e03146fd6 --- /dev/null +++ b/repository/shape/velty/ast/options.go @@ -0,0 +1,12 @@ +package ast + +type Options struct { + Lang string + StateName string + CallNotifier func(callExpr *CallExpr) (Expression, error) + AssignNotifier func(assign *Assign) (Expression, error) + SliceItemNotifier func(value, set *Ident) error + WithoutBusinessLogic bool + OnIfNotifier func(value *Condition) (Expression, error) + WithLowerCaseIdent bool +} diff --git a/repository/shape/velty/ast/scope.go b/repository/shape/velty/ast/scope.go new file mode 100644 index 000000000..82cb06fdd --- /dev/null +++ b/repository/shape/velty/ast/scope.go @@ -0,0 +1,63 @@ +package ast + +import "strings" + +type ( + Scope struct { + Variables map[string]*Variable + Parent *Scope + } + + Variable struct { + Name string + } +) + +func NewScope(declaredVariables ...string) *Scope { + variables := map[string]*Variable{} + for _, variable := range declaredVariables { + variables[variable] = &Variable{Name: variable} + } + return &Scope{ + Variables: variables, + } +} + +func (s *Scope) NextScope() *Scope { + scope := NewScope() + scope.Parent = s + return scope +} + +func (s *Scope) DeclareVariable(variable string) { + split := strings.Split(variable, ".") + if len(split) > 0 { + variable = split[0] + } + + if s.Variables[variable] != nil { + return + } + + s.Variables[variable] = &Variable{ + Name: variable, + } +} + +func (s *Scope) IsDeclared(variable string) bool { + dotIndex := strings.Index(variable, ".") + if dotIndex >= 0 { + return true + } + + tmp := s + for tmp != nil { + if _, ok := tmp.Variables[variable]; ok { + return true + } + + tmp = tmp.Parent + } + + return false +} diff --git a/repository/shape/velty/ast/star.go b/repository/shape/velty/ast/star.go new file mode 100644 index 000000000..769b03374 --- /dev/null +++ b/repository/shape/velty/ast/star.go @@ -0,0 +1,51 @@ +package ast + +type ( + DerefExpression struct { + X Expression + } + + RefExpression struct { + X Expression + } +) + +func NewRefExpression(x Expression) *RefExpression { + return &RefExpression{ + X: x, + } +} + +func NewDerefExpression(x Expression) *DerefExpression { + return &DerefExpression{X: x} +} + +func (s *DerefExpression) Generate(builder *Builder) error { + switch builder.Options.Lang { + case LangGO: + if err := builder.WriteString("*"); err != nil { + return err + } + + fallthrough + case LangVelty: + return s.X.Generate(builder) + } + + return unsupportedOptionUse(builder, s) +} + +func (s *RefExpression) Generate(builder *Builder) error { + switch builder.Options.Lang { + case LangGO: + if err := builder.WriteString("&"); err != nil { + return err + } + + fallthrough + case LangVelty: + return s.X.Generate(builder) + } + + return unsupportedOptionUse(builder, s) +} diff --git a/repository/shape/xgen/codegen.go b/repository/shape/xgen/codegen.go new file mode 100644 index 000000000..6716ea7ae --- /dev/null +++ b/repository/shape/xgen/codegen.go @@ -0,0 +1,3598 @@ +package xgen + +import ( + "bytes" + "context" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "path" + "path/filepath" + "sort" + "strings" + "time" + "unicode" + "unicode/utf8" + + "github.com/viant/datly/repository/shape/dql/shape" + shapeload "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/typectx" + utypes "github.com/viant/datly/utils/types" + "github.com/viant/datly/view" + "github.com/viant/datly/view/extension" + "github.com/viant/datly/view/state" + viewtags "github.com/viant/datly/view/tags" + "github.com/viant/tagly/format/text" + "github.com/viant/xreflect" + "github.com/viant/xunsafe" + "reflect" +) + +// ComponentCodegen generates Go source code for a complete component package: +// Input struct, Output struct, entity view structs, init() registration, +// //go:embed directive, and DefineComponent function. +// +// This is the shape pipeline equivalent of repository.Component.GenerateOutputCode. +type ComponentCodegen struct { + Component *shapeload.Component + Resource *view.Resource + TypeContext *typectx.Context + ProjectDir string + PackageDir string + PackageName string + PackagePath string + FileName string // defaults to .go + WithEmbed bool // generate //go:embed and EmbedFS method + WithContract bool // generate DefineComponent function + WithRegister *bool // generate init() with core.RegisterType (default: true) +} + +// ComponentCodegenResult captures generation outputs. +type ComponentCodegenResult struct { + FilePath string + PackageDir string + PackagePath string + PackageName string + Types []string + GeneratedFiles []string + InputFilePath string + OutputFilePath string + ViewFilePath string + RouterFilePath string + VeltyFilePath string + Embeds map[string]string // SQL file name → SQL content +} + +type codegenSelectorHolder struct { + FieldName string + QuerySelector string + Type reflect.Type +} + +// Generate produces the component Go source file. +func (g *ComponentCodegen) Generate() (*ComponentCodegenResult, error) { + if g.Component == nil { + return nil, fmt.Errorf("shape codegen: nil component") + } + if g.Resource == nil { + return nil, fmt.Errorf("shape codegen: nil resource") + } + + projectDir := g.ProjectDir + if projectDir == "" { + return nil, fmt.Errorf("shape codegen: project dir required") + } + + packageDir := g.PackageDir + if packageDir == "" && g.TypeContext != nil { + packageDir = g.TypeContext.PackageDir + } + if packageDir == "" { + return nil, fmt.Errorf("shape codegen: package dir required") + } + if !filepath.IsAbs(packageDir) { + packageDir = filepath.Join(projectDir, packageDir) + } + + packageName := g.PackageName + if packageName == "" && g.TypeContext != nil { + packageName = g.TypeContext.PackageName + } + if packageName == "" { + packageName = filepath.Base(packageDir) + } + + packagePath := g.PackagePath + if packagePath == "" && g.TypeContext != nil { + packagePath = g.TypeContext.PackagePath + } + + componentName := g.componentName() + inputTypeName := g.inputTypeName(componentName) + outputTypeName := g.outputTypeName(componentName) + rootViewTypeName := g.rootViewTypeName(componentName) + embedURI := text.CaseFormatUpperCamel.Format(componentName, text.CaseFormatLowerUnderscore) + explicitOutputParams := cloneCodegenParameters(g.Component.OutputParameters()) + hasExplicitOutput := len(explicitOutputParams) > 0 + + defaultFileName := g.FileName + if defaultFileName == "" { + defaultFileName = embedURI + ".go" + } + outputFileName := g.resolveOutputDestFileName(defaultFileName) + inputFileName := g.resolveInputDestFileName(outputFileName) + viewFileName := g.resolveViewDestFileName(outputFileName) + routerFileName := g.resolveRouterDestFileName("") + + shapeFragment, err := g.generateShapeFragment(projectDir, packageDir, packageName, packagePath) + if err != nil { + return nil, err + } + + // Build Input/Output types using state.Parameters.ReflectType + lookupType := g.componentLookupType(packagePath) + + var inputType, outputType reflect.Type + var selectorHolders []codegenSelectorHolder + inputParams := state.Parameters(nil) + if params := g.codegenInputParameters(); len(params) > 0 || strings.TrimSpace(g.Component.URI) != "" { + normalized := params + inputParams, selectorHolders = g.partitionInputParametersForCodegen(normalized, packagePath, lookupType) + normalizeBodyInputTypesForCodegen(inputParams, packagePath, lookupType) + inputOpts := []state.ReflectOption{state.WithSetMarker(), state.WithTypeName(inputTypeName)} + if g.componentUsesVelty() { + inputOpts = append(inputOpts, state.WithVelty(true)) + } + rt, err := inputParams.ReflectType(packagePath, lookupType, inputOpts...) + if err == nil && rt != nil { + inputType = rt + } + } + + // Build output parameters — use explicit ones or synthesize defaults for readers + outputParams := cloneCodegenParameters(explicitOutputParams) + if !hasExplicitOutput { + outputParams = g.defaultOutputParameters(componentName) + } + g.syncOutputSummarySchemasForCodegen(outputParams) + // Resolve wildcard output types to the view entity type + g.resolveOutputWildcardTypes(outputParams, componentName) + if len(outputParams) > 0 { + rt, err := outputParams.ReflectType(packagePath, lookupType) + if err == nil && rt != nil { + outputType = rt + } + } + + shapeTypeNames := map[string]bool{} + if shapeFragment != nil { + for _, typeName := range shapeFragment.Types { + typeName = strings.TrimSpace(typeName) + if typeName != "" { + shapeTypeNames[typeName] = true + } + } + } + + inputHelpers := collectNamedHelperTypes(inputType, packagePath, shapeTypeNames) + selectorHelpers := []namedHelperType{} + selectorTypeImports := []string{} + for _, holder := range selectorHolders { + if holder.Type == nil { + continue + } + selectorHelpers = append(selectorHelpers, collectNamedHelperTypes(holder.Type, packagePath, shapeTypeNames)...) + selectorTypeImports = mergeImportPaths(selectorTypeImports, collectTypeImports(holder.Type, packagePath)) + } + outputHelpers := collectNamedHelperTypes(outputType, packagePath, shapeTypeNames) + mutableSupport := g.mutableSupport(inputType) + emitResponseImport := g.outputUsesResponse(outputParams) || mutableSupport != nil + mutableOutputImports := []string{} + if mutableSupport != nil { + mutableOutputImports = append(mutableOutputImports, "github.com/viant/xdatly/handler/validator") + } + + var initBuilder strings.Builder + // init() registration + initBuilder.WriteString("func init() {\n") + if g.withRegister() { + registryPackage := strings.TrimSpace(packagePath) + if registryPackage == "" { + registryPackage = packageName + } + registered := map[string]bool{} + if inputType != nil { + initBuilder.WriteString(fmt.Sprintf("\tcore.RegisterType(%q, %q, reflect.TypeOf(%s{}), checksum.GeneratedTime)\n", + registryPackage, inputTypeName, inputTypeName)) + registered[inputTypeName] = true + } + initBuilder.WriteString(fmt.Sprintf("\tcore.RegisterType(%q, %q, reflect.TypeOf(%s{}), checksum.GeneratedTime)\n", + registryPackage, outputTypeName, outputTypeName)) + registered[outputTypeName] = true + if shapeFragment != nil { + for _, typeName := range shapeFragment.Types { + typeName = strings.TrimSpace(typeName) + if typeName == "" || registered[typeName] { + continue + } + initBuilder.WriteString(fmt.Sprintf("\tcore.RegisterType(%q, %q, reflect.TypeOf(%s{}), checksum.GeneratedTime)\n", + registryPackage, typeName, typeName)) + registered[typeName] = true + } + } + for _, helper := range inputHelpers { + if helper.TypeName == "" || registered[helper.TypeName] { + continue + } + initBuilder.WriteString(fmt.Sprintf("\tcore.RegisterType(%q, %q, reflect.TypeOf(%s{}), checksum.GeneratedTime)\n", + registryPackage, helper.TypeName, helper.TypeName)) + registered[helper.TypeName] = true + } + for _, helper := range outputHelpers { + if helper.TypeName == "" || registered[helper.TypeName] { + continue + } + initBuilder.WriteString(fmt.Sprintf("\tcore.RegisterType(%q, %q, reflect.TypeOf(%s{}), checksum.GeneratedTime)\n", + registryPackage, helper.TypeName, helper.TypeName)) + registered[helper.TypeName] = true + } + for _, helper := range selectorHelpers { + if helper.TypeName == "" || registered[helper.TypeName] { + continue + } + initBuilder.WriteString(fmt.Sprintf("\tcore.RegisterType(%q, %q, reflect.TypeOf(%s{}), checksum.GeneratedTime)\n", + registryPackage, helper.TypeName, helper.TypeName)) + registered[helper.TypeName] = true + } + } + initBuilder.WriteString("}\n\n") + + var inputBuilder strings.Builder + if inputType != nil || g.WithContract { + inputBuilder.WriteString(fmt.Sprintf("type %s struct {\n", inputTypeName)) + if inputType != nil { + inputBuilder.WriteString(structFieldsSource(inputType)) + } + if mutableSupport != nil { + mutableSupport.renderInputFields(&inputBuilder) + } + inputBuilder.WriteString("}\n\n") + } + for _, helper := range inputHelpers { + inputBuilder.WriteString(helper.Decl) + } + if g.WithEmbed && inputType != nil { + inputBuilder.WriteString(fmt.Sprintf("func (i *%s) EmbedFS() *embed.FS {\n", inputTypeName)) + inputBuilder.WriteString(fmt.Sprintf("\treturn &%sFS\n", componentName)) + inputBuilder.WriteString("}\n\n") + } + + var outputBuilder strings.Builder + var routerBuilder strings.Builder + if g.WithEmbed { + outputBuilder.WriteString(fmt.Sprintf("//go:embed %s/*.sql\n", embedURI)) + outputBuilder.WriteString(fmt.Sprintf("var %sFS embed.FS\n\n", componentName)) + } + outputRenderParams := cloneCodegenParameters(explicitOutputParams) + if !hasExplicitOutput { + outputRenderParams = g.defaultOutputParameters(componentName) + } + g.resolveOutputWildcardTypes(outputRenderParams, componentName) + g.renderOutputStruct(&outputBuilder, outputTypeName, rootViewTypeName, embedURI, outputRenderParams, outputType, mutableSupport) + for _, helper := range outputHelpers { + outputBuilder.WriteString(helper.Decl) + } + if g.WithContract { + g.renderComponentHolder(&routerBuilder, componentName, inputTypeName, outputTypeName, selectorHolders) + for _, helper := range selectorHelpers { + routerBuilder.WriteString(helper.Decl) + } + g.renderDefineComponent(&outputBuilder, componentName, inputTypeName, outputTypeName) + } + + viewDecls := "" + viewImports := []string{} + if shapeFragment != nil { + viewDecls = strings.TrimSpace(shapeFragment.TypeDecls) + viewImports = shapeFragment.Imports + } + + outputParts := []string{initBuilder.String(), outputBuilder.String()} + routerParts := []string{} + if strings.TrimSpace(routerBuilder.String()) != "" { + if routerFileName == "" || routerFileName == outputFileName { + outputParts = append(outputParts, routerBuilder.String()) + } else { + routerParts = append(routerParts, routerBuilder.String()) + } + } + inputParts := []string{inputBuilder.String()} + viewParts := []string{} + if viewDecls != "" { + viewParts = append(viewParts, viewDecls+"\n") + } + + if err := os.MkdirAll(packageDir, 0o755); err != nil { + return nil, err + } + outputDest := filepath.Join(packageDir, outputFileName) + inputDest := outputDest + viewDest := outputDest + routerDest := "" + inputInitDest := "" + inputValidateDest := "" + if inputFileName != "" { + inputDest = filepath.Join(packageDir, inputFileName) + } + if viewFileName != "" { + viewDest = filepath.Join(packageDir, viewFileName) + } + if routerFileName != "" { + routerDest = filepath.Join(packageDir, routerFileName) + } + if mutableSupport != nil && inputDest != "" { + base := "input" + if inputFileName != "" && inputFileName != outputFileName { + base = strings.TrimSuffix(filepath.Base(inputDest), filepath.Ext(inputDest)) + } else if g.Component != nil && g.Component.Directives != nil { + if dest := strings.TrimSpace(g.Component.Directives.InputDest); dest != "" { + base = strings.TrimSuffix(filepath.Base(dest), filepath.Ext(dest)) + } + } + if strings.TrimSpace(base) == "" { + base = "input" + } + inputInitDest = filepath.Join(packageDir, base+"_init.go") + inputValidateDest = filepath.Join(packageDir, base+"_validate.go") + } + var generatedFiles []string + appendGenerated := func(dest string) { + dest = strings.TrimSpace(dest) + if dest == "" { + return + } + for _, candidate := range generatedFiles { + if candidate == dest { + return + } + } + generatedFiles = append(generatedFiles, dest) + } + split := outputFileName != inputFileName || outputFileName != viewFileName || (routerFileName != "" && routerFileName != outputFileName) + var writeErr error + if !split { + imports := mergeImportPaths( + g.buildImports(g.WithContract && (routerFileName == "" || routerFileName == outputFileName), emitResponseImport), + viewImports, + collectTypeImports(inputType, packagePath), + collectTypeImports(outputType, packagePath), + selectorTypeImports, + helperImports(inputHelpers), + helperImports(selectorHelpers), + helperImports(outputHelpers), + mutableOutputImports, + ) + combined := append(append(outputParts, inputParts...), viewParts...) + writeErr = g.writeSectionFile(outputDest, packageName, imports, combined...) + if writeErr == nil { + appendGenerated(outputDest) + outputFileName = outputDest + } + } else { + outputImports := mergeImportPaths( + g.buildImports(g.WithContract && (routerFileName == "" || routerFileName == outputFileName), emitResponseImport), + collectTypeImports(outputType, packagePath), + helperImports(outputHelpers), + mutableOutputImports, + ) + if routerFileName == "" || routerFileName == outputFileName { + outputImports = mergeImportPaths(outputImports, selectorTypeImports, helperImports(selectorHelpers)) + } + if viewFileName == outputFileName { + outputImports = mergeImportPaths(outputImports, viewImports) + } + if writeErr = g.writeSectionFile(outputDest, packageName, outputImports, outputParts...); writeErr != nil { + return nil, writeErr + } + appendGenerated(outputDest) + if inputFileName == outputFileName { + if writeErr = g.appendSectionFile(outputDest, inputParts...); writeErr != nil { + return nil, writeErr + } + } else if strings.TrimSpace(strings.Join(inputParts, "")) != "" { + inputImports := []string{} + if g.WithEmbed && inputType != nil { + inputImports = append(inputImports, "embed") + } + inputImports = mergeImportPaths(inputImports, collectTypeImports(inputType, packagePath), helperImports(inputHelpers)) + if viewFileName == inputFileName { + inputImports = mergeImportPaths(inputImports, viewImports) + } + if writeErr = g.writeSectionFile(inputDest, packageName, inputImports, inputParts...); writeErr != nil { + return nil, writeErr + } + appendGenerated(inputDest) + } + if len(viewParts) > 0 { + if viewFileName == outputFileName { + if writeErr = g.appendSectionFile(outputDest, viewParts...); writeErr != nil { + return nil, writeErr + } + } else if viewFileName == inputFileName { + if _, statErr := os.Stat(inputDest); statErr == nil { + if writeErr = g.appendSectionFile(inputDest, viewParts...); writeErr != nil { + return nil, writeErr + } + } else { + if writeErr = g.writeSectionFile(inputDest, packageName, viewImports, viewParts...); writeErr != nil { + return nil, writeErr + } + appendGenerated(inputDest) + } + } else { + if writeErr = g.writeSectionFile(viewDest, packageName, viewImports, viewParts...); writeErr != nil { + return nil, writeErr + } + appendGenerated(viewDest) + } + } + if len(routerParts) > 0 { + routerImports := mergeImportPaths(g.buildRouterImports(), selectorTypeImports, helperImports(selectorHelpers)) + if writeErr = g.writeSectionFile(routerDest, packageName, routerImports, routerParts...); writeErr != nil { + return nil, writeErr + } + appendGenerated(routerDest) + } + outputFileName = outputDest + } + if writeErr != nil { + return nil, writeErr + } + if mutableSupport != nil { + if writeErr = g.writeSectionFile(inputInitDest, packageName, []string{"context", "github.com/viant/xdatly/handler"}, mutableSupport.renderInputInit(inputTypeName, outputTypeName)); writeErr != nil { + return nil, writeErr + } + appendGenerated(inputInitDest) + if writeErr = g.writeSectionFile(inputValidateDest, packageName, []string{"context", "github.com/viant/xdatly/handler", "github.com/viant/xdatly/handler/validator"}, mutableSupport.renderInputValidate(inputTypeName, outputTypeName)); writeErr != nil { + return nil, writeErr + } + appendGenerated(inputValidateDest) + } + veltyDest := "" + if mutableSupport != nil && !g.componentUsesHandler() { + var veltyBody string + var ok bool + veltyBody, ok, writeErr = g.renderMutableDSQL(inputType) + if writeErr != nil { + return nil, writeErr + } + if ok { + veltyDest = filepath.Join(packageDir, text.CaseFormatUpperCamel.Format(componentName, text.CaseFormatLowerUnderscore), "patch.sql") + if writeErr = os.MkdirAll(filepath.Dir(veltyDest), 0o755); writeErr != nil { + return nil, writeErr + } + if writeErr = os.WriteFile(veltyDest, []byte(veltyBody), 0o644); writeErr != nil { + return nil, writeErr + } + appendGenerated(veltyDest) + for _, helperFile := range g.mutableHelperSQLFiles(mutableSupport) { + if strings.TrimSpace(helperFile.Path) == "" || strings.TrimSpace(helperFile.Content) == "" { + continue + } + if writeErr = os.MkdirAll(filepath.Dir(helperFile.Path), 0o755); writeErr != nil { + return nil, writeErr + } + if writeErr = os.WriteFile(helperFile.Path, []byte(helperFile.Content), 0o644); writeErr != nil { + return nil, writeErr + } + appendGenerated(helperFile.Path) + } + } + } + + var typeNames []string + if inputType != nil { + typeNames = append(typeNames, inputTypeName) + } + if outputType != nil { + typeNames = append(typeNames, outputTypeName) + } + if shapeFragment != nil { + typeNames = append(typeNames, shapeFragment.Types...) + } + + return &ComponentCodegenResult{ + FilePath: outputFileName, + PackageDir: packageDir, + PackagePath: packagePath, + PackageName: packageName, + Types: typeNames, + GeneratedFiles: generatedFiles, + InputFilePath: inputDest, + OutputFilePath: outputDest, + ViewFilePath: viewDest, + RouterFilePath: routerDest, + VeltyFilePath: veltyDest, + }, nil +} + +func normalizeBodyInputTypesForCodegen(params state.Parameters, pkgPath string, lookupType xreflect.LookupType) { + for _, param := range params { + if param == nil || param.In == nil || param.In.Kind != state.KindRequestBody || param.Schema == nil { + continue + } + if param.Schema.Cardinality != state.One { + continue + } + rType := param.Schema.Type() + if rType == nil { + if resolved, err := utypes.LookupType(lookupType, param.Schema.DataType, xreflect.WithPackage(param.Schema.Package)); err == nil && resolved != nil { + rType = resolved + } else if resolved, err := utypes.LookupType(lookupType, param.Schema.DataType, xreflect.WithPackage(pkgPath)); err == nil && resolved != nil { + rType = resolved + } + } + if rType != nil && rType.Kind() == reflect.Struct { + param.Schema.SetType(reflect.PtrTo(rType)) + } + } +} + +func (g *ComponentCodegen) refreshSummarySchemasForCodegen() { + if g == nil || g.Resource == nil { + return + } + visited := map[*view.View]bool{} + for _, aView := range g.Resource.Views { + g.refreshViewSummarySchemasForCodegen(context.Background(), aView, visited) + } +} + +func (g *ComponentCodegen) refreshViewSummarySchemasForCodegen(ctx context.Context, aView *view.View, visited map[*view.View]bool) { + if aView == nil || visited[aView] { + return + } + visited[aView] = true + if aView.Template != nil && aView.Template.Summary != nil { + _ = aView.Template.Init(ctx, g.Resource, aView) + } + for _, rel := range aView.With { + if rel == nil || rel.Of == nil { + continue + } + g.refreshViewSummarySchemasForCodegen(ctx, &rel.Of.View, visited) + } +} + +func (g *ComponentCodegen) syncOutputSummarySchemasForCodegen(params state.Parameters) { + root := g.rootResourceView() + if root == nil || root.Template == nil || root.Template.Summary == nil || root.Template.Summary.Schema == nil { + return + } + for _, param := range params { + if param == nil || param.In == nil || param.In.Name != "summary" { + continue + } + param.Schema = root.Template.Summary.Schema.Clone() + } +} + +func normalizeInputParametersForCodegen(params state.Parameters, resource *view.Resource, uri string) state.Parameters { + result := make(state.Parameters, 0, len(params)+4) + seenPath := map[string]bool{} + var stateResource state.Resource + if resource != nil { + stateResource = view.NewResources(resource, &view.View{}) + } + for _, item := range params { + if item == nil { + continue + } + cloned := *item + schema := normalizeInputSchemaForCodegen(item.Name, item.In, item.Required != nil && *item.Required, item.Schema, resource) + cloned.Schema = schema + if cloned.Schema != nil && stateResource != nil { + _ = cloned.Schema.Init(stateResource) + if cloned.In != nil && cloned.In.Kind == state.KindRequestBody && cloned.Schema.Cardinality == state.One { + normalizeBodySchemaPointerForCodegen(cloned.Schema) + } + } + if cloned.Output != nil { + output := *cloned.Output + if cloned.Output.Schema != nil { + output.Schema = cloned.Output.Schema.Clone() + } + cloned.Output = &output + if stateResource != nil && cloned.Schema != nil && cloned.Schema.Type() != nil { + _ = cloned.Output.Init(stateResource, cloned.Schema.Type()) + } + } + if in := item.In; in != nil && in.Kind == state.KindView { + viewName := strings.TrimSpace(item.Name) + if name := strings.TrimSpace(in.Name); name != "" { + viewName = name + } + if v := lookupInputView(resource, viewName); v != nil { + cloned.Tag = mergeViewSQLTag(cloned.Tag, v) + } + } + cloned.Tag = ensureCodegenTypeNameTag(cloned.Tag, cloned.Schema) + if in := cloned.In; in != nil && in.Kind == state.KindPath { + key := strings.ToLower(strings.TrimSpace(in.Name)) + if key == "" { + key = strings.ToLower(strings.TrimSpace(cloned.Name)) + } + if key != "" { + seenPath[key] = true + } + } + result = append(result, &cloned) + } + for _, name := range extractCodegenRoutePathParams(uri) { + key := strings.ToLower(strings.TrimSpace(name)) + if key == "" || seenPath[key] { + continue + } + fieldName := name + result = append(result, &state.Parameter{ + Name: fieldName, + In: state.NewPathLocation(name), + Schema: &state.Schema{ + DataType: "string", + Cardinality: state.One, + }, + }) + seenPath[key] = true + } + return result +} + +func (g *ComponentCodegen) codegenInputParameters() state.Parameters { + if g == nil || g.Component == nil { + return nil + } + params := cloneCodegenParameters(g.Component.InputParameters()) + params = g.mergeMutableTemplateInputParametersForCodegen(params) + return normalizeInputParametersForCodegen(params, g.Resource, g.Component.URI) +} + +func (g *ComponentCodegen) mergeMutableTemplateInputParametersForCodegen(params state.Parameters) state.Parameters { + if g == nil || !g.componentUsesVelty() { + return params + } + root := g.rootResourceView() + if root == nil || root.Template == nil || !root.Template.UseParameterStateType || len(root.Template.Parameters) == 0 { + return params + } + result := cloneCodegenParameters(params) + seen := map[string]bool{} + for _, item := range result { + if item == nil { + continue + } + seen[codegenParameterKey(item)] = true + } + for _, item := range root.Template.Parameters { + if item == nil { + continue + } + key := codegenParameterKey(item) + if seen[key] { + continue + } + cloned := *item + if item.Schema != nil { + cloned.Schema = item.Schema.Clone() + } + if item.Output != nil { + output := *item.Output + if item.Output.Schema != nil { + output.Schema = item.Output.Schema.Clone() + } + cloned.Output = &output + } + result = append(result, &cloned) + seen[key] = true + } + return result +} + +func exportedCodegenParamName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + return strings.ToUpper(name[:1]) + name[1:] +} + +func extractCodegenRoutePathParams(uri string) []string { + uri = strings.TrimSpace(uri) + if uri == "" { + return nil + } + var result []string + seen := map[string]bool{} + for { + start := strings.IndexByte(uri, '{') + if start == -1 { + break + } + uri = uri[start+1:] + end := strings.IndexByte(uri, '}') + if end == -1 { + break + } + name := strings.TrimSpace(uri[:end]) + uri = uri[end+1:] + if name == "" { + continue + } + key := strings.ToLower(name) + if seen[key] { + continue + } + seen[key] = true + result = append(result, name) + } + return result +} + +func cloneCodegenParameters(params state.Parameters) state.Parameters { + if len(params) == 0 { + return nil + } + result := make(state.Parameters, 0, len(params)) + for _, item := range params { + if item == nil { + continue + } + cloned := *item + if item.Schema != nil { + cloned.Schema = item.Schema.Clone() + } + if item.Output != nil { + output := *item.Output + if item.Output.Schema != nil { + output.Schema = item.Output.Schema.Clone() + } + cloned.Output = &output + } + result = append(result, &cloned) + } + return result +} + +func (g *ComponentCodegen) partitionInputParametersForCodegen(params state.Parameters, packagePath string, lookupType xreflect.LookupType) (state.Parameters, []codegenSelectorHolder) { + if len(params) == 0 { + return nil, nil + } + selectorByKey := map[string]string{} + if g != nil && g.Component != nil { + for _, item := range g.Component.Input { + if item == nil || strings.TrimSpace(item.QuerySelector) == "" { + continue + } + selectorByKey[codegenParameterKey(&item.Parameter)] = strings.TrimSpace(item.QuerySelector) + } + } + + business := make(state.Parameters, 0, len(params)) + grouped := map[string]state.Parameters{} + order := []string{} + for _, item := range params { + if item == nil { + continue + } + querySelector := selectorByKey[codegenParameterKey(item)] + if querySelector == "" { + business = append(business, item) + continue + } + if _, ok := grouped[querySelector]; !ok { + order = append(order, querySelector) + } + grouped[querySelector] = append(grouped[querySelector], item) + } + + if len(order) == 0 { + return business, nil + } + + holders := make([]codegenSelectorHolder, 0, len(order)) + usedNames := map[string]bool{} + for i, querySelector := range order { + group := grouped[querySelector] + holderType, err := group.ReflectType(packagePath, lookupType) + if err != nil || holderType == nil { + business = append(business, group...) + continue + } + holders = append(holders, codegenSelectorHolder{ + FieldName: selectorHolderFieldName(querySelector, i, len(order), usedNames), + QuerySelector: querySelector, + Type: holderType, + }) + } + return business, holders +} + +func codegenParameterKey(param *state.Parameter) string { + if param == nil { + return "" + } + kind := "" + inName := "" + if param.In != nil { + kind = strings.ToLower(strings.TrimSpace(string(param.In.Kind))) + inName = strings.ToLower(strings.TrimSpace(param.In.Name)) + } + return strings.ToLower(strings.TrimSpace(param.Name)) + "|" + kind + "|" + inName +} + +func selectorHolderFieldName(querySelector string, index, total int, used map[string]bool) string { + name := "ViewSelect" + if total > 1 { + base := toUpperCamel(querySelector) + if base != "" { + name = base + "Select" + } else { + name = fmt.Sprintf("ViewSelect%d", index+1) + } + } + candidate := name + if used == nil { + return candidate + } + for suffix := 2; used[candidate]; suffix++ { + candidate = fmt.Sprintf("%s%d", name, suffix) + } + used[candidate] = true + return candidate +} + +func normalizeInputSchemaForCodegen(paramName string, in *state.Location, required bool, schema *state.Schema, resource *view.Resource) *state.Schema { + var cloned state.Schema + if schema != nil { + cloned = exportedSchemaCopy(schema) + } + kind := state.Kind("") + if in != nil { + kind = in.Kind + } + if kind == state.KindView { + if viewSchema := lookupViewSchemaForInput(resource, in, paramName); viewSchema != nil { + base := exportedSchemaCopy(viewSchema) + if explicit := strings.TrimSpace(cloned.Name); explicit != "" && strings.TrimSpace(base.Name) == "" { + base.Name = explicit + } + if explicit := strings.TrimSpace(cloned.DataType); explicit != "" && !isDynamicTypeName(explicit) && strings.TrimSpace(base.DataType) == "" { + base.DataType = explicit + } + if explicit := strings.TrimSpace(cloned.Package); explicit != "" && strings.TrimSpace(base.Package) == "" { + base.Package = explicit + } + if explicit := strings.TrimSpace(cloned.PackagePath); explicit != "" && strings.TrimSpace(base.PackagePath) == "" { + base.PackagePath = explicit + } + if explicit := strings.TrimSpace(cloned.ModulePath); explicit != "" && strings.TrimSpace(base.ModulePath) == "" { + base.ModulePath = explicit + } + if explicit := cloned.Cardinality; explicit != "" { + base.Cardinality = explicit + } + cloned = base + } + } + if cloned.Cardinality == "" { + if kind == state.KindView { + if required { + cloned.Cardinality = state.One + } else { + cloned.Cardinality = state.Many + } + } else { + cloned.Cardinality = state.One + } + } + if kind != state.KindView && strings.TrimSpace(cloned.DataType) == "" { + cloned.DataType = "string" + } + if kind == state.KindRequestBody && cloned.Cardinality == state.One { + normalizeBodySchemaPointerForCodegen(&cloned) + } + return &cloned +} + +func normalizeBodySchemaPointerForCodegen(schema *state.Schema) { + if schema == nil { + return + } + if rType := schema.Type(); rType != nil { + for rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + return + } + if rType.Kind() == reflect.Ptr { + return + } + if rType.Kind() == reflect.Struct { + schema.SetType(reflect.PtrTo(rType)) + } + } + dataType := strings.TrimSpace(schema.DataType) + if dataType == "" || strings.HasPrefix(dataType, "*") || strings.HasPrefix(dataType, "[]") { + return + } + if strings.HasPrefix(dataType, "struct {") || strings.HasPrefix(dataType, "interface{") || dataType == "string" || dataType == "int" || dataType == "bool" || dataType == "float64" { + return + } + schema.DataType = "*" + dataType +} + +func exportedSchemaCopy(schema *state.Schema) state.Schema { + if schema == nil { + return state.Schema{} + } + result := state.Schema{ + Package: schema.Package, + PackagePath: schema.PackagePath, + ModulePath: schema.ModulePath, + Name: schema.Name, + DataType: schema.DataType, + Cardinality: schema.Cardinality, + Methods: append([]reflect.Method(nil), schema.Methods...), + } + if rType := schema.Type(); rType != nil { + result.SetType(rType) + if schema.Package != "" { + result.Package = schema.Package + } + if schema.PackagePath != "" { + result.PackagePath = schema.PackagePath + } + if schema.ModulePath != "" { + result.ModulePath = schema.ModulePath + } + } + return result +} + +func lookupViewSchemaForInput(resource *view.Resource, in *state.Location, paramName string) *state.Schema { + if v := lookupInputView(resource, strings.TrimSpace(paramName)); v != nil && v.Schema != nil { + return v.Schema + } + if in != nil { + if v := lookupInputView(resource, strings.TrimSpace(in.Name)); v != nil && v.Schema != nil { + return v.Schema + } + } + return nil +} + +func lookupInputView(resource *view.Resource, name string) *view.View { + if resource == nil { + return nil + } + name = normalizeViewLookupName(name) + if name == "" { + return nil + } + for _, item := range resource.Views { + if item == nil { + continue + } + candidates := []string{ + item.Name, + item.Reference.Ref, + } + if item.Schema != nil { + candidates = append(candidates, item.Schema.Name) + } + for _, candidate := range candidates { + if normalizeViewLookupName(candidate) == name { + return item + } + } + } + return nil +} + +func normalizeViewLookupName(value string) string { + value = strings.TrimSpace(strings.ToLower(value)) + if value == "" { + return "" + } + var ret strings.Builder + ret.Grow(len(value)) + for _, r := range value { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { + ret.WriteRune(r) + } + } + return ret.String() +} + +func mergeViewSQLTag(existing string, aView *view.View) string { + tag := buildViewMetadataTag(aView, true, true) + if tag == nil { + return existing + } + return string(tag.UpdateTag(reflect.StructTag(existing))) +} + +func buildViewMetadataTag(aView *view.View, includeName bool, includeSQL bool) *viewtags.Tag { + if aView == nil { + return nil + } + result := &viewtags.Tag{} + tagView := &viewtags.View{} + if includeName { + tagView.Name = strings.TrimSpace(aView.Name) + } + if table := strings.TrimSpace(aView.Table); isStableTableName(table) { + tagView.Table = table + } + if aView.Template != nil && aView.Template.Summary != nil { + tagView.SummaryURI = strings.TrimSpace(aView.Template.Summary.SourceURL) + } + if aView.Groupable { + value := true + tagView.Groupable = &value + } + if aView.Batch != nil && aView.Batch.Size > 0 && aView.Batch.Size != 10000 { + tagView.Batch = aView.Batch.Size + } + if aView.RelationalConcurrency != nil && aView.RelationalConcurrency.Number > 0 && aView.RelationalConcurrency.Number != 1 { + tagView.RelationalConcurrency = aView.RelationalConcurrency.Number + } + if aView.PublishParent { + tagView.PublishParent = true + } + if aView.Partitioned != nil { + tagView.PartitionerType = aView.Partitioned.DataType + tagView.PartitionedConcurrency = aView.Partitioned.Concurrency + } + if aView.MatchStrategy != "" && aView.MatchStrategy != view.ReadMatched { + tagView.Match = string(aView.MatchStrategy) + } + if aView.Cache != nil { + tagView.Cache = strings.TrimSpace(aView.Cache.Reference.Ref) + } + if aView.Connector != nil && aView.Connector.Ref != "" { + tagView.Connector = aView.Connector.Ref + } + if selector := aView.Selector; selector != nil { + if ns := strings.TrimSpace(selector.Namespace); ns != "" { + tagView.SelectorNamespace = ns + } + if selector.NoLimit || selector.Limit != 0 { + limit := selector.Limit + tagView.Limit = &limit + } + if constraints := selector.Constraints; constraints != nil { + if constraints.Criteria { + value := true + tagView.SelectorCriteria = &value + } + if constraints.Projection { + value := true + tagView.SelectorProjection = &value + } + if constraints.OrderBy { + value := true + tagView.SelectorOrderBy = &value + } + if constraints.Offset { + value := true + tagView.SelectorOffset = &value + } + if constraints.Page != nil { + value := *constraints.Page + tagView.SelectorPage = &value + } + if len(constraints.Filterable) > 0 { + tagView.SelectorFilterable = append([]string(nil), constraints.Filterable...) + } + if len(constraints.OrderByColumn) > 0 { + tagView.SelectorOrderByColumns = map[string]string{} + for key, value := range constraints.OrderByColumn { + tagView.SelectorOrderByColumns[key] = value + } + } + } + } + if aView.Tag != "" { + tagView.CustomTag = aView.Tag + } + if tagView.Name != "" || tagView.Table != "" || tagView.SummaryURI != "" || tagView.CustomTag != "" || tagView.Connector != "" || + tagView.Cache != "" || tagView.Limit != nil || tagView.Match != "" || tagView.Batch > 0 || + tagView.PublishParent || tagView.PartitionerType != "" || tagView.RelationalConcurrency > 0 || + tagView.Groupable != nil || tagView.SelectorNamespace != "" || tagView.SelectorCriteria != nil || + tagView.SelectorProjection != nil || tagView.SelectorOrderBy != nil || tagView.SelectorOffset != nil || + tagView.SelectorPage != nil || len(tagView.SelectorFilterable) > 0 || len(tagView.SelectorOrderByColumns) > 0 { + result.View = tagView + } + if includeSQL && aView.Template != nil { + if sourceURL := strings.TrimSpace(aView.Template.SourceURL); sourceURL != "" { + result.SQL = viewtags.NewViewSQL("", sourceURL) + } + } + if result.View == nil && result.SQL.URI == "" && result.SQL.SQL == "" { + return nil + } + return result +} + +func removeTagKeys(tag string, keys ...string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return "" + } + for _, key := range keys { + var updated string + updated, _ = xreflect.RemoveTag(tag, key) + tag = strings.TrimSpace(updated) + } + return tag +} + +func ensureCodegenTypeNameTag(tag string, schema *state.Schema) string { + if schema == nil { + return strings.TrimSpace(tag) + } + typeName := strings.TrimSpace(schema.Name) + if typeName == "" { + return strings.TrimSpace(tag) + } + tag = removeTagKeys(tag, "typeName") + tag = strings.TrimSpace(tag) + if tag == "" { + return fmt.Sprintf(`typeName:"%s"`, typeName) + } + return tag + ` typeName:"` + typeName + `"` +} + +func isDynamicTypeName(name string) bool { + n := strings.TrimSpace(strings.ToLower(name)) + n = strings.ReplaceAll(n, " ", "") + switch n { + case "", "interface{}", "any", "*interface{}", "[]interface{}", "[]any": + return true + } + return false +} + +func (g *ComponentCodegen) componentLookupType(packagePath string) xreflect.LookupType { + localTypes := map[string]reflect.Type{} + if g != nil && g.Resource != nil { + for _, aView := range g.Resource.Views { + if aView == nil { + continue + } + typeName := "" + if aView.Schema != nil { + typeName = strings.TrimSpace(aView.Schema.Name) + } + if typeName == "" { + typeName = toUpperCamel(strings.TrimSpace(aView.Name)) + "View" + } + if typeName == "" { + continue + } + var rType reflect.Type + if aView.Schema != nil && aView.Schema.Type() != nil { + rType = aView.Schema.Type() + } + if rType == nil && len(aView.Columns) > 0 { + rType = buildStructType(columnsFromView(aView), g.viewUsesVelty(aView)) + } + if rType == nil { + continue + } + key := strings.ToLower(typeName) + localTypes[key] = rType + if summary := summaryTemplateOf(aView); summary != nil && summary.Schema != nil { + if summaryType := summary.Schema.Type(); summaryType != nil { + summaryName := strings.TrimSpace(summary.Schema.Name) + if summaryName == "" { + summaryName = strings.TrimSpace(summary.Name) + } + if summaryName != "" { + localTypes[strings.ToLower(summaryName)] = summaryType + } + } + } + } + } + return func(name string, opts ...xreflect.Option) (reflect.Type, error) { + base := normalizeLookupTypeName(name) + if base != "" { + if rType, ok := localTypes[strings.ToLower(base)]; ok { + return rType, nil + } + if packagePath != "" { + if linked := xunsafe.LookupType(packagePath + "/" + base); linked != nil { + return linked, nil + } + } + } + if builtin, ok := builtinTypeByName(name); ok { + return builtin, nil + } + if builtin, ok := builtinTypeByName(base); ok { + return builtin, nil + } + return nil, fmt.Errorf("type %s not found", name) + } +} + +func builtinTypeByName(name string) (reflect.Type, bool) { + name = strings.TrimSpace(name) + if name == "" { + return nil, false + } + if strings.HasPrefix(name, "[]") { + if elem, ok := builtinTypeByName(strings.TrimPrefix(name, "[]")); ok { + return reflect.SliceOf(elem), true + } + } + if strings.HasPrefix(name, "*") { + if elem, ok := builtinTypeByName(strings.TrimPrefix(name, "*")); ok { + return reflect.PtrTo(elem), true + } + } + switch name { + case "string": + return reflect.TypeOf(""), true + case "bool": + return reflect.TypeOf(true), true + case "int": + return reflect.TypeOf(int(0)), true + case "int8": + return reflect.TypeOf(int8(0)), true + case "int16": + return reflect.TypeOf(int16(0)), true + case "int32": + return reflect.TypeOf(int32(0)), true + case "int64": + return reflect.TypeOf(int64(0)), true + case "uint": + return reflect.TypeOf(uint(0)), true + case "uint8": + return reflect.TypeOf(uint8(0)), true + case "uint16": + return reflect.TypeOf(uint16(0)), true + case "uint32": + return reflect.TypeOf(uint32(0)), true + case "uint64": + return reflect.TypeOf(uint64(0)), true + case "float32": + return reflect.TypeOf(float32(0)), true + case "float64": + return reflect.TypeOf(float64(0)), true + case "time.Time": + return reflect.TypeOf(time.Time{}), true + } + return nil, false +} + +func normalizeLookupTypeName(name string) string { + name = strings.TrimSpace(name) + for strings.HasPrefix(name, "*") || strings.HasPrefix(name, "[]") { + if strings.HasPrefix(name, "*") { + name = strings.TrimPrefix(name, "*") + continue + } + name = strings.TrimPrefix(name, "[]") + } + if idx := strings.LastIndex(name, "."); idx != -1 { + name = name[idx+1:] + } + return strings.TrimSpace(name) +} + +func columnsFromView(aView *view.View) []columnDescriptor { + result := make([]columnDescriptor, 0, len(aView.Columns)) + for _, col := range aView.Columns { + if col == nil { + continue + } + result = append(result, columnDescriptor{ + name: strings.TrimSpace(col.Name), + dataType: strings.TrimSpace(col.DataType), + nullable: col.Nullable, + }) + } + return result +} + +func toUpperCamel(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + var b strings.Builder + capNext := true + for _, r := range s { + if r == '_' || r == '-' || r == ' ' || r == '.' || r == '/' { + capNext = true + continue + } + if capNext { + b.WriteRune(unicode.ToUpper(r)) + capNext = false + continue + } + b.WriteRune(r) + } + return b.String() +} + +type shapeFragment struct { + Types []string + Imports []string + TypeDecls string +} + +func (g *ComponentCodegen) generateShapeFragment(projectDir, packageDir, packageName, packagePath string) (*shapeFragment, error) { + if g == nil || g.Resource == nil || len(g.Resource.Views) == 0 { + return &shapeFragment{}, nil + } + shapeDoc := resourceToShapeDocument(g.Resource, g.TypeContext) + applyShapeDocViewTypeOverrides(shapeDoc.Root, g.Component) + shapeCfg := &Config{ + ProjectDir: projectDir, + PackageDir: packageDir, + PackageName: packageName, + PackagePath: packagePath, + } + if overrides := collectViewTypeOverrides(g.Component); len(overrides) > 0 { + shapeCfg.ViewTypeNamer = func(ctx ViewTypeContext) string { + if value := strings.TrimSpace(overrides[strings.ToLower(strings.TrimSpace(ctx.ViewName))]); value != "" { + return value + } + return "" + } + } + hydrateConfigFromTypeContext(shapeDoc, shapeCfg) + applyDefaults(shapeCfg) + return g.renderSemanticShapeFragment(shapeCfg, packagePath) +} + +func (g *ComponentCodegen) renderSemanticShapeFragment(shapeCfg *Config, packagePath string) (*shapeFragment, error) { + viewDescriptorsByName := map[string]viewDescriptor{} + shapeDoc := resourceToShapeDocument(g.Resource, g.TypeContext) + for _, item := range extractViews(shapeDoc.Root) { + viewDescriptorsByName[strings.ToLower(strings.TrimSpace(asString(item.name)))] = item + } + typeNames := make([]string, 0, len(g.Resource.Views)) + registered := map[string]bool{} + imports := map[string]bool{} + var decls strings.Builder + for _, aView := range g.Resource.Views { + if aView == nil { + continue + } + typeName := g.resourceViewTypeName(shapeCfg, aView) + if typeName == "" || registered[typeName] { + continue + } + mutable := false + if descriptor, ok := viewDescriptorsByName[strings.ToLower(strings.TrimSpace(aView.Name))]; ok { + mutable = descriptor.mutable + } + viewDecl, viewImports, err := g.renderSemanticViewDecl(shapeCfg, aView, packagePath, mutable) + if err != nil { + return nil, err + } + if strings.TrimSpace(viewDecl) == "" { + continue + } + registered[typeName] = true + typeNames = append(typeNames, typeName) + decls.WriteString(viewDecl) + decls.WriteString("\n") + for _, imp := range viewImports { + imports[imp] = true + } + for _, summary := range g.summaryTypeDecls(aView, packagePath) { + if registered[summary.name] { + continue + } + registered[summary.name] = true + typeNames = append(typeNames, summary.name) + decls.WriteString(summary.decl) + decls.WriteString("\n") + for _, imp := range summary.imports { + imports[imp] = true + } + } + + if descriptor, ok := viewDescriptorsByName[strings.ToLower(strings.TrimSpace(aView.Name))]; ok && descriptor.mutable { + structType := buildHasType(columnsFromView(aView)) + if structType != nil { + hasTypeName := typeName + "Has" + if !registered[hasTypeName] { + registered[hasTypeName] = true + typeNames = append(typeNames, hasTypeName) + decls.WriteString(fmt.Sprintf("type %s struct {\n", hasTypeName)) + decls.WriteString(structFieldsSource(structType)) + decls.WriteString("}\n\n") + } + } + } + } + mergedImports := make([]string, 0, len(imports)) + for imp := range imports { + mergedImports = append(mergedImports, imp) + } + sort.Strings(mergedImports) + return &shapeFragment{ + Types: typeNames, + Imports: mergeImportPaths(mergedImports), + TypeDecls: strings.TrimSpace(decls.String()), + }, nil +} + +type emittedTypeDecl struct { + name string + decl string + imports []string +} + +func (g *ComponentCodegen) summaryTypeDecls(aView *view.View, currentPackage string) []emittedTypeDecl { + if aView == nil { + return nil + } + seen := map[string]bool{} + var result []emittedTypeDecl + appendSummary := func(summary *view.TemplateSummary) { + if summary == nil || summary.Schema == nil { + return + } + name := strings.TrimSpace(summary.Schema.Name) + rType := ensureCodegenStructType(summary.Schema.Type()) + if name == "" || rType == nil || seen[name] { + return + } + seen[name] = true + result = append(result, emittedTypeDecl{ + name: name, + decl: fmt.Sprintf("type %s struct {\n%s}\n\n", name, structFieldsSource(rType)), + imports: collectTypeImports(rType, currentPackage), + }) + } + appendSummary(summaryTemplateOf(aView)) + for _, rel := range aView.With { + child := g.semanticView(g.resolveRelationView(rel)) + appendSummary(summaryTemplateOf(child)) + } + return result +} + +func summaryTemplateOf(aView *view.View) *view.TemplateSummary { + if aView == nil || aView.Template == nil { + return nil + } + return aView.Template.Summary +} + +func (g *ComponentCodegen) resourceViewTypeName(shapeCfg *Config, aView *view.View) string { + if aView == nil { + return "" + } + descriptor := viewDescriptor{ + name: aView.Name, + schemaName: "", + columns: columnsFromView(aView), + } + if aView.Schema != nil { + descriptor.schemaName = aView.Schema.Name + } + return viewTypeName(shapeCfg, descriptor) +} + +func (g *ComponentCodegen) renderSemanticViewDecl(shapeCfg *Config, aView *view.View, currentPackage string, mutable bool) (string, []string, error) { + aView = g.semanticView(aView) + typeName := g.resourceViewTypeName(shapeCfg, aView) + if typeName == "" { + return "", nil, nil + } + var builder strings.Builder + builder.WriteString(fmt.Sprintf("type %s struct {\n", typeName)) + imports := map[string]bool{} + emittedFields := map[string]bool{} + appendField := func(fieldName, fieldSrc string, fieldImports []string) { + if strings.TrimSpace(fieldName) == "" { + fieldName = renderedFieldName(fieldSrc) + } + fieldName = strings.TrimSpace(fieldName) + if fieldName == "" || emittedFields[fieldName] || strings.TrimSpace(fieldSrc) == "" { + return + } + emittedFields[fieldName] = true + builder.WriteString(fieldSrc) + for _, imp := range fieldImports { + imports[imp] = true + } + } + renderedScalar := false + for _, column := range aView.Columns { + fieldSrc, fieldImports := g.renderColumnField(aView, column, currentPackage) + if fieldSrc == "" { + continue + } + renderedScalar = true + appendField("", fieldSrc, fieldImports) + } + if !renderedScalar { + for _, field := range g.renderScalarFallbackFields(aView, currentPackage) { + fieldName := strings.TrimSpace(strings.Split(strings.TrimSpace(field.src), " ")[0]) + appendField(fieldName, field.src, field.imports) + } + } + for _, rel := range aView.With { + fieldSrc, fieldImports := g.renderRelationField(shapeCfg, aView, rel, currentPackage) + if fieldSrc == "" { + continue + } + appendField(strings.TrimSpace(rel.Holder), fieldSrc, fieldImports) + if metaSrc, metaImports := g.renderRelationSummaryField(shapeCfg, rel, currentPackage); metaSrc != "" { + fieldName := "" + if rel.Of.Template != nil && rel.Of.Template.Summary != nil { + fieldName = state.StructFieldName(text.CaseFormatUpperCamel, rel.Of.Template.Summary.Name) + } + appendField(fieldName, metaSrc, metaImports) + } + } + if aView.SelfReference != nil { + if holder := strings.TrimSpace(aView.SelfReference.Holder); holder != "" { + builder.WriteString(fmt.Sprintf("\t%s []interface{} `sqlx:\"-\"`\n", holder)) + } + } + if mutable { + hasTypeName := typeName + "Has" + builder.WriteString(fmt.Sprintf("\tHas *%s `setMarker:\"true\" format:\"-\" sqlx:\"-\" diff:\"-\" json:\"-\" typeName:\"%s\"`\n", hasTypeName, hasTypeName)) + } + builder.WriteString("}\n\n") + resultImports := make([]string, 0, len(imports)) + for imp := range imports { + resultImports = append(resultImports, imp) + } + sort.Strings(resultImports) + return builder.String(), resultImports, nil +} + +func (g *ComponentCodegen) renderColumnField(aView *view.View, column *view.Column, currentPackage string) (string, []string) { + if aView == nil || column == nil { + return "", nil + } + fieldName := column.FieldName() + if strings.TrimSpace(fieldName) == "" { + caseFormat := aView.CaseFormat + if !caseFormat.IsDefined() { + caseFormat = text.CaseFormatLowerUnderscore + } + fieldName = state.StructFieldName(caseFormat, column.Name) + } + rType := column.ColumnType() + if rType == nil { + if builtin, ok := builtinTypeByName(column.DataType); ok { + rType = builtin + } else if g != nil && g.Resource != nil { + if lookup := g.Resource.LookupType(); lookup != nil { + if resolved, err := utypes.LookupType(lookup, column.DataType); err == nil && resolved != nil { + rType = resolved + } + } + if rType == nil && extension.Config != nil && extension.Config.Types != nil { + if resolved, err := utypes.LookupType(extension.Config.Types.Lookup, column.DataType); err == nil && resolved != nil { + rType = resolved + } + } + } + } + if rType == nil { + rType = reflect.TypeOf((*interface{})(nil)).Elem() + } + rType = g.normalizeColumnType(column, rType) + tag := g.columnFieldTag(aView, column) + return fmt.Sprintf("\t%s %s `%s`\n", fieldName, goTypeString(rType), tag), collectTypeImports(rType, currentPackage) +} + +func (g *ComponentCodegen) renderRelationField(shapeCfg *Config, parent *view.View, rel *view.Relation, currentPackage string) (string, []string) { + if rel == nil { + return "", nil + } + holder := strings.TrimSpace(rel.Holder) + if holder == "" { + return "", nil + } + childTypeName := g.relationTypeName(shapeCfg, rel) + if childTypeName == "" { + return "", nil + } + typeExpr := "*" + childTypeName + if rel.Cardinality == state.Many { + typeExpr = "[]*" + childTypeName + } + tag := g.relationFieldTag(parent, rel) + return fmt.Sprintf("\t%s %s `%s`\n", holder, typeExpr, tag), nil +} + +func (g *ComponentCodegen) renderRelationSummaryField(shapeCfg *Config, rel *view.Relation, currentPackage string) (string, []string) { + child := g.semanticView(g.resolveRelationView(rel)) + if child == nil || child.Template == nil || child.Template.Summary == nil || child.Template.Summary.Schema == nil { + return "", nil + } + meta := child.Template.Summary + fieldName := state.StructFieldName(text.CaseFormatUpperCamel, meta.Name) + if strings.TrimSpace(fieldName) == "" { + return "", nil + } + typeName := strings.TrimSpace(meta.Schema.Name) + if typeName == "" { + return "", nil + } + typeExpr := "*" + typeName + tag := fmt.Sprintf(`json:",omitempty" yaml:",omitempty" sqlx:"-" typeName:"%s"`, typeName) + return fmt.Sprintf("\t%s %s `%s`\n", fieldName, typeExpr, tag), collectTypeImports(meta.Schema.Type(), currentPackage) +} + +func (g *ComponentCodegen) relationTypeName(shapeCfg *Config, rel *view.Relation) string { + if rel == nil { + return "" + } + for _, name := range []string{ + strings.TrimSpace(rel.Of.View.Name), + strings.TrimSpace(rel.Of.View.Reference.Ref), + strings.TrimSpace(rel.Name), + strings.TrimSpace(rel.Holder), + } { + if name == "" { + continue + } + if spec := g.typeSpec("view:" + strings.ToLower(strings.TrimSpace(name))); spec != nil && strings.TrimSpace(spec.TypeName) != "" { + return strings.TrimSpace(spec.TypeName) + } + } + if candidate := g.semanticView(g.resolveRelationView(rel)); candidate != nil { + if typeName := strings.TrimSpace(g.resourceViewTypeName(shapeCfg, candidate)); typeName != "" { + return typeName + } + } + if rel.Of.Schema != nil && strings.TrimSpace(rel.Of.Schema.Name) != "" { + return strings.TrimSpace(rel.Of.Schema.Name) + } + refNames := []string{ + strings.TrimSpace(rel.Of.View.Name), + strings.TrimSpace(rel.Of.View.Reference.Ref), + strings.TrimSpace(rel.Name), + } + for _, refName := range refNames { + if refName == "" { + continue + } + for _, candidate := range g.Resource.Views { + if candidate == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(candidate.Name), refName) || strings.EqualFold(strings.TrimSpace(candidate.Reference.Ref), refName) { + return g.resourceViewTypeName(shapeCfg, candidate) + } + } + } + return "" +} + +func (g *ComponentCodegen) generatedIndexColumn(aView *view.View) (*view.Column, string, reflect.Type, bool) { + if aView == nil { + return nil, "", nil, false + } + var candidate *view.Column + for _, column := range aView.Columns { + if column == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(column.FieldName()), "Id") || strings.EqualFold(strings.TrimSpace(column.Name), "ID") { + candidate = column + break + } + } + if candidate == nil { + return nil, "", nil, false + } + fieldName := strings.TrimSpace(candidate.FieldName()) + if fieldName == "" { + caseFormat := aView.CaseFormat + if !caseFormat.IsDefined() { + caseFormat = text.CaseFormatLowerUnderscore + } + fieldName = state.StructFieldName(caseFormat, candidate.Name) + } + rType := candidate.ColumnType() + if rType == nil { + if builtin, ok := builtinTypeByName(candidate.DataType); ok { + rType = builtin + } + } + if rType == nil { + return nil, "", nil, false + } + rType = g.normalizeColumnType(candidate, rType) + return candidate, fieldName, rType, true +} + +func (g *ComponentCodegen) columnFieldTag(aView *view.View, column *view.Column) string { + tag := strings.TrimSpace(column.Tag) + cleaned, _ := xreflect.RemoveTag(tag, "velty") + tag = strings.TrimSpace(cleaned) + groupable := column.Groupable + if aView != nil && aView.ColumnsConfig != nil { + if cfg := aView.ColumnsConfig[column.Name]; cfg != nil { + if cfg.Groupable != nil { + groupable = *cfg.Groupable + } + if cfg.Tag != nil { + configTag := strings.TrimSpace(strings.Trim(*cfg.Tag, ` `)) + if configTag != "" && !strings.Contains(tag, configTag) { + if tag != "" { + tag += " " + } + tag += configTag + } + } + } + } + if aView != nil && containsFold(aView.Exclude, column.Name) && !strings.Contains(tag, `internal:"true"`) { + if tag != "" { + tag += " " + } + tag += `internal:"true"` + } + if groupable && !strings.Contains(tag, `groupable:"`) { + if tag != "" { + tag += " " + } + tag += `groupable:"true"` + } + sqlxValue := strings.TrimSpace(column.Name) + if column.Codec != nil && strings.TrimSpace(column.DataType) != "" { + sqlxValue += ",type=" + strings.TrimSpace(column.DataType) + } + if sqlxValue != "" && !strings.Contains(tag, `sqlx:"`) { + if tag != "" { + tag += " " + } + tag += fmt.Sprintf(`sqlx:"%s"`, sqlxValue) + } + if g.resourceViewUsesVelty(aView) && !strings.Contains(tag, `velty:"`) { + caseFormat := aView.CaseFormat + if !caseFormat.IsDefined() { + caseFormat = text.CaseFormatLowerUnderscore + } + if tag != "" { + tag += " " + } + tag += fmt.Sprintf(`velty:"%s"`, generateVeltyTagValue(column.Name, caseFormat)) + } + return normalizeGeneratedTagOrder(strings.TrimSpace(tag)) +} + +func (g *ComponentCodegen) viewUsesVelty(aView *view.View) bool { + if aView == nil { + return false + } + switch aView.Mode { + case view.ModeExec: + return true + default: + return false + } +} + +func (g *ComponentCodegen) resourceViewUsesVelty(aView *view.View) bool { + if aView == nil || g == nil || g.Component == nil || !g.componentUsesVelty() || g.componentUsesHandler() { + return false + } + if g.viewUsesVelty(aView) { + return true + } + matches := func(value string) bool { + value = strings.TrimSpace(value) + if value == "" { + return false + } + return strings.EqualFold(strings.TrimSpace(aView.Name), value) || + strings.EqualFold(strings.TrimSpace(aView.Reference.Ref), value) + } + for _, input := range g.Component.Input { + if input == nil || input.In == nil || input.In.Kind != state.KindView { + continue + } + if matches(input.In.Name) || matches(input.Name) { + return true + } + } + return false +} + +func (g *ComponentCodegen) componentUsesMutableHelpers() bool { + if g == nil || g.Component == nil || !g.componentUsesVelty() || g.componentUsesHandler() { + return false + } + hasBody := false + hasView := false + for _, input := range g.Component.Input { + if input == nil || input.In == nil { + continue + } + switch input.In.Kind { + case state.KindRequestBody: + hasBody = true + case state.KindView: + hasView = true + } + } + return hasBody && hasView +} + +func (g *ComponentCodegen) componentUsesVelty() bool { + if g == nil { + return false + } + if g.componentUsesHandler() { + return false + } + if g.Resource != nil && g.Component != nil { + rootViewName := strings.TrimSpace(g.Component.RootView) + if rootViewName != "" { + if rootView, _ := g.Resource.View(rootViewName); rootView != nil { + return g.viewUsesVelty(rootView) + } + } + } + if g.Component == nil { + return false + } + method := strings.ToUpper(strings.TrimSpace(g.Component.Method)) + return method != "" && method != "GET" +} + +func (g *ComponentCodegen) componentUsesHandler() bool { + if g == nil || g.Component == nil { + return false + } + for _, route := range g.Component.ComponentRoutes { + if route != nil && strings.TrimSpace(route.Handler) != "" { + return true + } + } + if g.Resource != nil { + if rootViewName := strings.TrimSpace(g.Component.RootView); rootViewName != "" { + if rootView, _ := g.Resource.View(rootViewName); rootView != nil && rootView.Mode == view.ModeHandler { + return true + } + } + } + return false +} + +func normalizeGeneratedTagOrder(tag string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return tag + } + ordered := make([]string, 0, 4) + for _, key := range []string{"sqlx", "internal", "groupable", "velty", "json"} { + value := reflect.StructTag(tag).Get(key) + if value == "" { + continue + } + ordered = append(ordered, fmt.Sprintf(`%s:%q`, key, value)) + var updated string + updated, _ = xreflect.RemoveTag(tag, key) + tag = strings.TrimSpace(updated) + } + if tag != "" { + ordered = append(ordered, tag) + } + return strings.Join(ordered, " ") +} + +func (g *ComponentCodegen) relationFieldTag(parent *view.View, rel *view.Relation) string { + child := g.semanticView(g.resolveRelationView(rel)) + if child == nil { + return "" + } + tag := &viewtags.Tag{} + if metadata := buildViewMetadataTag(child, false, false); metadata != nil { + tag.View = metadata.View + } + if relTag := strings.TrimSpace(child.Tag); relTag != "" { + if tag.View == nil { + tag.View = &viewtags.View{} + } + tag.View.CustomTag = relTag + } + if parent != nil && parent.Cache != nil { + if tag.View == nil { + tag.View = &viewtags.View{} + } + tag.View.Cache = parent.Cache.Ref + } + if parent != nil && parent.Connector != nil && child.Connector != nil && child.Connector.Ref != parent.Connector.Ref { + if tag.View == nil { + tag.View = &viewtags.View{} + } + tag.View.Connector = child.Connector.Ref + } + tag.LinkOn = g.relationLinkTag(parent, child, rel) + if child.Template != nil { + tag.SQL = viewtags.NewViewSQL("", strings.TrimSpace(child.Template.SourceURL)) + } + return string(tag.UpdateTag(``)) +} + +func (g *ComponentCodegen) normalizeColumnType(column *view.Column, rType reflect.Type) reflect.Type { + if column == nil || rType == nil { + return rType + } + for rType.Kind() == reflect.Ptr { + if column.Nullable { + return rType + } + rType = rType.Elem() + } + if column.Nullable && rType.Kind() != reflect.Interface && rType.Kind() != reflect.Slice && rType.Kind() != reflect.Map { + return reflect.PtrTo(rType) + } + return rType +} + +type renderedField struct { + src string + imports []string +} + +func renderedFieldName(src string) string { + src = strings.TrimSpace(src) + if src == "" { + return "" + } + parts := strings.Fields(src) + if len(parts) == 0 { + return "" + } + return strings.TrimSpace(parts[0]) +} + +func (g *ComponentCodegen) renderScalarFallbackFields(aView *view.View, currentPackage string) []renderedField { + aView = g.semanticView(aView) + rType := g.resourceViewStructType(aView.Name) + rType = ensureCodegenStructType(rType) + if rType == nil && aView != nil { + rType = ensureCodegenStructType(aView.ComponentType()) + if rType == nil && aView.Schema != nil { + rType = ensureCodegenStructType(aView.Schema.Type()) + } + } + if rType == nil { + return nil + } + var result []renderedField + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if !field.IsExported() { + continue + } + if strings.TrimSpace(field.Tag.Get("view")) != "" || strings.TrimSpace(field.Tag.Get("on")) != "" { + continue + } + result = append(result, renderedField{ + src: fmt.Sprintf("\t%s %s `%s`\n", field.Name, goTypeString(field.Type), string(field.Tag)), + imports: collectTypeImports(field.Type, currentPackage), + }) + } + return result +} + +func (g *ComponentCodegen) resolveRelationView(rel *view.Relation) *view.View { + if rel == nil { + return nil + } + names := []string{ + strings.TrimSpace(rel.Of.View.Reference.Ref), + strings.TrimSpace(rel.Of.View.Name), + strings.TrimSpace(rel.Name), + } + for _, name := range names { + if name == "" || g == nil || g.Resource == nil { + continue + } + for _, candidate := range g.Resource.Views { + if candidate == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(candidate.Name), name) || strings.EqualFold(strings.TrimSpace(candidate.Reference.Ref), name) { + return candidate + } + } + } + return &rel.Of.View +} + +func (g *ComponentCodegen) semanticView(aView *view.View) *view.View { + if aView == nil { + return nil + } + merged := *aView + if merged.ColumnsConfig == nil { + merged.ColumnsConfig = map[string]*view.ColumnConfig{} + } + if g == nil || g.Resource == nil { + return &merged + } + for _, parent := range g.Resource.Views { + if parent == nil { + continue + } + for _, rel := range parent.With { + if rel == nil { + continue + } + if !g.matchesViewRef(&merged, rel) { + continue + } + g.mergeViewSemantics(&merged, &rel.Of.View) + } + } + return &merged +} + +func (g *ComponentCodegen) matchesViewRef(target *view.View, rel *view.Relation) bool { + if target == nil || rel == nil { + return false + } + candidates := []string{ + strings.TrimSpace(target.Name), + strings.TrimSpace(target.Reference.Ref), + } + refs := []string{ + strings.TrimSpace(rel.Of.View.Name), + strings.TrimSpace(rel.Of.View.Reference.Ref), + strings.TrimSpace(rel.Name), + } + for _, candidate := range candidates { + if candidate == "" { + continue + } + for _, ref := range refs { + if ref != "" && strings.EqualFold(candidate, ref) { + return true + } + } + } + return false +} + +func (g *ComponentCodegen) mergeViewSemantics(dst, src *view.View) { + if dst == nil || src == nil { + return + } + if len(dst.Columns) == 0 && len(src.Columns) > 0 { + dst.Columns = src.Columns + } + if len(dst.Exclude) == 0 && len(src.Exclude) > 0 { + dst.Exclude = append(dst.Exclude, src.Exclude...) + } + if dst.ColumnsConfig == nil { + dst.ColumnsConfig = map[string]*view.ColumnConfig{} + } + for key, cfg := range src.ColumnsConfig { + if _, ok := dst.ColumnsConfig[key]; !ok { + dst.ColumnsConfig[key] = cfg + } + } + if dst.Template == nil && src.Template != nil { + dst.Template = src.Template + } + if dst.Template != nil && src.Template != nil { + if strings.TrimSpace(dst.Template.Source) == "" { + dst.Template.Source = src.Template.Source + } + if strings.TrimSpace(dst.Template.SourceURL) == "" { + dst.Template.SourceURL = src.Template.SourceURL + } + if src.Template.Summary != nil { + if dst.Template.Summary == nil { + dst.Template.Summary = src.Template.Summary + } else { + if strings.TrimSpace(dst.Template.Summary.Name) == "" { + dst.Template.Summary.Name = src.Template.Summary.Name + } + if dst.Template.Summary.Kind == "" { + dst.Template.Summary.Kind = src.Template.Summary.Kind + } + if strings.TrimSpace(dst.Template.Summary.Source) == "" { + dst.Template.Summary.Source = src.Template.Summary.Source + } + if strings.TrimSpace(dst.Template.Summary.SourceURL) == "" { + dst.Template.Summary.SourceURL = src.Template.Summary.SourceURL + } + if src.Template.Summary.Schema != nil && (dst.Template.Summary.Schema == nil || dst.Template.Summary.Schema.Type() == nil) { + dst.Template.Summary.Schema = src.Template.Summary.Schema + } + } + } + } + if !isStableTableName(dst.Table) && isStableTableName(src.Table) { + dst.Table = src.Table + } + if dst.Schema == nil && src.Schema != nil { + dst.Schema = src.Schema + return + } + if dst.Schema != nil && src.Schema != nil { + if dst.Schema.Type() == nil && src.Schema.Type() != nil { + dst.Schema.SetType(src.Schema.Type()) + } + if strings.TrimSpace(dst.Schema.Name) == "" { + dst.Schema.Name = src.Schema.Name + } + } +} + +func (g *ComponentCodegen) relationLinkTag(parent, child *view.View, rel *view.Relation) viewtags.LinkOn { + if rel == nil { + return nil + } + result := make([]string, 0, len(rel.On)) + for i, parentLink := range rel.On { + if parentLink == nil { + continue + } + var childLink *view.Link + if i < len(rel.Of.On) { + childLink = rel.Of.On[i] + } + left := g.encodeRelationEndpoint(parent, parentLink) + right := g.encodeRelationEndpoint(child, childLink) + if left != "" && right != "" { + result = append(result, left+"="+right) + } + } + return result +} + +func (g *ComponentCodegen) encodeRelationEndpoint(owner *view.View, link *view.Link) string { + if link == nil { + return "" + } + column := stripNamespace(link.Column) + field := strings.TrimSpace(link.Field) + if field == "" { + caseFormat := text.CaseFormatLowerUnderscore + if owner != nil && owner.CaseFormat.IsDefined() { + caseFormat = owner.CaseFormat + } + field = state.StructFieldName(caseFormat, column) + } + if field == "" { + return column + } + if column == "" { + return field + } + return field + ":" + column +} + +func stripNamespace(value string) string { + value = strings.TrimSpace(value) + if idx := strings.LastIndex(value, "."); idx != -1 { + return strings.TrimSpace(value[idx+1:]) + } + return value +} + +func looksLikeSQL(value string) bool { + value = strings.TrimSpace(strings.ToUpper(value)) + return strings.Contains(value, "SELECT ") || strings.Contains(value, "\n") || strings.Contains(value, "(") +} + +func isStableTableName(value string) bool { + value = strings.TrimSpace(value) + if value == "" || looksLikeSQL(value) { + return false + } + for _, r := range value { + switch { + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case r == '_' || r == '.': + default: + return false + } + } + return true +} + +func containsFold(items []string, candidate string) bool { + candidate = strings.TrimSpace(candidate) + for _, item := range items { + if strings.EqualFold(strings.TrimSpace(item), candidate) { + return true + } + } + return false +} + +func generateVeltyTagValue(columnName string, caseFormat text.CaseFormat) string { + names := columnName + if fieldName := state.StructFieldName(caseFormat, columnName); fieldName != names { + names += "|" + fieldName + } + return "names=" + names +} + +func goTypeString(rType reflect.Type) string { + if rType == nil { + return "interface{}" + } + switch rType.Kind() { + case reflect.Ptr: + return "*" + goTypeString(rType.Elem()) + case reflect.Slice: + return "[]" + goTypeString(rType.Elem()) + case reflect.Array: + return fmt.Sprintf("[%d]%s", rType.Len(), goTypeString(rType.Elem())) + case reflect.Map: + return fmt.Sprintf("map[%s]%s", goTypeString(rType.Key()), goTypeString(rType.Elem())) + } + if rType.Name() != "" { + if pkg := strings.TrimSpace(rType.PkgPath()); pkg != "" { + prefix := filepath.Base(pkg) + if prefix != "" && prefix != "." { + return prefix + "." + rType.Name() + } + } + return rType.Name() + } + return rType.String() +} + +func (g *ComponentCodegen) resourceViewStructType(name any) reflect.Type { + if g == nil || g.Resource == nil { + return nil + } + viewName := strings.TrimSpace(asString(name)) + if viewName == "" { + return nil + } + for _, aView := range g.Resource.Views { + if aView == nil || !strings.EqualFold(strings.TrimSpace(aView.Name), viewName) { + continue + } + rType := aView.ComponentType() + if rType == nil && aView.Schema != nil { + rType = aView.Schema.Type() + } + rType = ensureCodegenStructType(rType) + if rebuilt := rebuildResourceViewStructType(rType, columnsFromView(aView), g.resourceViewUsesVelty(aView)); rebuilt != nil { + rType = rebuilt + } + if augmented := g.augmentResourceViewStructType(aView, rType); augmented != nil { + return augmented + } + return rType + } + return nil +} + +func ensureCodegenStructType(rType reflect.Type) reflect.Type { + if rType == nil { + return nil + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + return rType +} + +func shouldRenderResourceViewType(rType reflect.Type, columns []columnDescriptor) bool { + rType = ensureCodegenStructType(rType) + if rType == nil { + return false + } + if resourceViewNeedsRebuild(rType, columns, false) { + return true + } + return rType.NumField() > len(columns) +} + +func rebuildResourceViewStructType(rType reflect.Type, columns []columnDescriptor, includeVelty bool) reflect.Type { + rType = ensureCodegenStructType(rType) + if rType == nil { + if len(columns) == 0 { + return nil + } + return reflect.StructOf(buildStructFields(columns, includeVelty)) + } + if !resourceViewNeedsRebuild(rType, columns, includeVelty) { + return rType + } + fields := buildStructFields(columns, includeVelty) + if len(fields) == 0 { + return rType + } + used := map[string]bool{} + for _, field := range fields { + used[field.Name] = true + } + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if isPlaceholderProjectionField(field) { + continue + } + if used[field.Name] { + continue + } + fields = append(fields, field) + used[field.Name] = true + } + return reflect.StructOf(fields) +} + +func (g *ComponentCodegen) augmentResourceViewStructType(aView *view.View, rType reflect.Type) reflect.Type { + rType = ensureCodegenStructType(rType) + if aView == nil || rType == nil { + return rType + } + fields := make([]reflect.StructField, 0, rType.NumField()+len(aView.With)+1) + used := map[string]bool{} + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + fields = append(fields, field) + used[field.Name] = true + } + if aView.SelfReference != nil { + if holder := strings.TrimSpace(aView.SelfReference.Holder); holder != "" && !used[holder] { + fields = append(fields, reflect.StructField{ + Name: holder, + Type: reflect.TypeOf([]interface{}{}), + Tag: `sqlx:"-"`, + }) + used[holder] = true + } + } + for _, rel := range aView.With { + if rel == nil { + continue + } + holder := strings.TrimSpace(rel.Holder) + if holder == "" || used[holder] { + continue + } + fieldType := g.relationHolderType(rel) + if fieldType == nil { + continue + } + tagParts := []string{} + if table := strings.TrimSpace(rel.Of.View.Table); table != "" { + tagParts = append(tagParts, fmt.Sprintf(`view:",table=%s"`, table)) + } else { + tagParts = append(tagParts, `view:""`) + } + tagParts = append(tagParts, `sqlx:"-"`) + fields = append(fields, reflect.StructField{ + Name: holder, + Type: fieldType, + Tag: reflect.StructTag(strings.Join(tagParts, " ")), + }) + used[holder] = true + } + if len(fields) == rType.NumField() { + return rType + } + return reflect.StructOf(fields) +} + +func (g *ComponentCodegen) relationHolderType(rel *view.Relation) reflect.Type { + if rel == nil { + return nil + } + childType := ensureCodegenStructType(rel.Of.View.ComponentType()) + if childType == nil && rel.Of.Schema != nil { + childType = ensureCodegenStructType(rel.Of.Schema.Type()) + } + if childType == nil && g != nil && g.Resource != nil { + refNames := []string{ + strings.TrimSpace(rel.Of.View.Name), + strings.TrimSpace(rel.Of.View.Reference.Ref), + strings.TrimSpace(rel.Name), + } + for _, refName := range refNames { + if refName == "" { + continue + } + childType = g.resourceViewStructType(refName) + if childType != nil { + break + } + } + } + if childType == nil { + return nil + } + childPtr := childType + if childPtr.Kind() != reflect.Ptr { + childPtr = reflect.PtrTo(childType) + } + if rel.Cardinality == state.One { + return childPtr + } + return reflect.SliceOf(childPtr) +} + +func resourceViewNeedsRebuild(rType reflect.Type, columns []columnDescriptor, includeVelty bool) bool { + rType = ensureCodegenStructType(rType) + if rType == nil { + return false + } + if len(columns) == 0 { + return false + } + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if isPlaceholderProjectionField(field) { + return true + } + if includeVelty && field.Tag.Get("sqlx") != "" && field.Tag.Get("sqlx") != "-" && field.Tag.Get("velty") == "" { + return true + } + if !includeVelty && field.Tag.Get("velty") != "" { + return true + } + } + return false +} + +func isPlaceholderProjectionField(field reflect.StructField) bool { + if tag := strings.TrimSpace(field.Tag.Get("view")); tag != "" { + return false + } + if tag := strings.TrimSpace(field.Tag.Get("sql")); tag != "" { + return false + } + sqlxTag := field.Tag.Get("sqlx") + sqlxName := sqlxTagName(sqlxTag) + if sqlxName == "" || sqlxName == "-" { + return false + } + name := strings.TrimSpace(field.Name) + if strings.HasPrefix(strings.ToLower(name), "col") && strings.HasPrefix(strings.ToLower(sqlxName), "col_") { + return true + } + return false +} + +func sqlxTagName(tag string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return "" + } + if strings.HasPrefix(tag, "name=") { + tag = strings.TrimPrefix(tag, "name=") + } + if idx := strings.Index(tag, ","); idx != -1 { + tag = tag[:idx] + } + return strings.TrimSpace(tag) +} + +func collectTypeImports(rType reflect.Type, currentPackage string) []string { + seen := map[string]bool{} + var result []string + var visit func(reflect.Type) + visit = func(t reflect.Type) { + if t == nil { + return + } + for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice || t.Kind() == reflect.Array || t.Kind() == reflect.Map { + if t.Kind() == reflect.Map { + visit(t.Key()) + } + t = t.Elem() + if t == nil { + return + } + } + if pkg := strings.TrimSpace(t.PkgPath()); pkg != "" && pkg != currentPackage { + if !seen[pkg] { + seen[pkg] = true + result = append(result, pkg) + } + if t.Name() != "" { + return + } + } + if t.Kind() != reflect.Struct { + return + } + for i := 0; i < t.NumField(); i++ { + visit(t.Field(i).Type) + } + } + visit(rType) + sort.Strings(result) + return result +} + +func applyShapeDocViewTypeOverrides(root map[string]any, component *shapeload.Component) { + if root == nil || component == nil || len(component.TypeSpecs) == 0 { + return + } + resourceMap, _ := root["Resource"].(map[string]any) + if resourceMap == nil { + return + } + views, _ := resourceMap["Views"].([]any) + if len(views) == 0 { + return + } + for _, raw := range views { + viewMap, _ := raw.(map[string]any) + if viewMap == nil { + continue + } + name, _ := viewMap["Name"].(string) + name = strings.TrimSpace(name) + if name == "" { + continue + } + spec, ok := component.TypeSpecs["view:"+name] + if !ok || spec == nil || strings.TrimSpace(spec.TypeName) == "" { + continue + } + schemaMap, _ := viewMap["Schema"].(map[string]any) + if schemaMap == nil { + schemaMap = map[string]any{} + } + schemaMap["Name"] = strings.TrimSpace(spec.TypeName) + viewMap["Schema"] = schemaMap + } +} + +func collectViewTypeOverrides(component *shapeload.Component) map[string]string { + if component == nil || len(component.TypeSpecs) == 0 { + return nil + } + ret := map[string]string{} + for key, spec := range component.TypeSpecs { + if spec == nil || spec.Role != shapeload.TypeRoleView { + continue + } + typeName := strings.TrimSpace(spec.TypeName) + if typeName == "" { + continue + } + alias := strings.TrimSpace(spec.Alias) + if alias == "" && strings.HasPrefix(key, "view:") { + alias = strings.TrimPrefix(key, "view:") + } + if alias == "" { + continue + } + ret[strings.ToLower(alias)] = typeName + } + if len(ret) == 0 { + return nil + } + return ret +} + +func mergeImportPaths(groups ...[]string) []string { + var result []string + seen := map[string]bool{} + for _, group := range groups { + for _, item := range group { + item = strings.TrimSpace(item) + if item == "" || seen[item] { + continue + } + seen[item] = true + result = append(result, item) + } + } + return result +} + +func extractTypeDeclsAndImports(source string) ([]string, string, error) { + fset := token.NewFileSet() + fileNode, err := parser.ParseFile(fset, "", source, parser.ParseComments) + if err != nil { + return nil, "", err + } + var imports []string + for _, spec := range fileNode.Imports { + pathValue := strings.Trim(spec.Path.Value, `"`) + if spec.Name != nil && spec.Name.Name != "" && spec.Name.Name != "." && spec.Name.Name != "_" { + imports = append(imports, spec.Name.Name+` "`+pathValue+`"`) + continue + } + imports = append(imports, pathValue) + } + + var body bytes.Buffer + for _, decl := range fileNode.Decls { + if typeDecl, ok := decl.(*ast.GenDecl); ok && typeDecl.Tok == token.TYPE { + if err := format.Node(&body, fset, typeDecl); err != nil { + return nil, "", err + } + body.WriteString("\n\n") + } + } + return imports, strings.TrimSpace(body.String()), nil +} + +// renderOutputStruct writes the output struct definition. +// For reader components, it generates the standard pattern: +// +// type XxxOutput struct { +// response.Status `parameter:",kind=output,in=status" json:",omitempty"` +// Data []*XxxView `parameter:",kind=output,in=view" view:"xxx" sql:"uri=xxx/xxx.sql"` +// } +func (g *ComponentCodegen) renderOutputStruct(builder *strings.Builder, outputTypeName, viewTypeName, embedURI string, outputParams state.Parameters, outputType reflect.Type, mutableSupport *mutableComponentSupport) { + rootView := g.Component.RootView + rootViewMetadata := g.rootResourceView() + + builder.WriteString(fmt.Sprintf("type %s struct {\n", outputTypeName)) + + // Check if there's an explicit status parameter + hasStatus := false + hasViolations := false + for _, p := range outputParams { + if p == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(p.Name), "Violations") { + hasViolations = true + } + if p.In != nil && p.In.Name == "status" { + hasStatus = true + } + } + if !hasStatus && (len(outputParams) > 0 || g.shouldDefaultReaderOutput() || mutableSupport != nil) { + builder.WriteString("\tresponse.Status `parameter:\",kind=output,in=status\" json:\",omitempty\"`\n") + } + + for _, p := range outputParams { + if p == nil || p.In == nil { + continue + } + switch p.In.Name { + case "view": + cardinality := string(p.Schema.Cardinality) + typePrefix := "[]*" + if cardinality == string(state.One) { + typePrefix = "*" + } + fieldName := p.Name + if fieldName == "" || fieldName == "Output" { + fieldName = "Data" + } + tag := strings.TrimSpace(p.Tag) + if !strings.Contains(tag, `parameter:"`) { + tag = strings.TrimSpace(tag + ` parameter:",kind=output,in=view"`) + } + if !strings.Contains(tag, `view:"`) { + tag = strings.TrimSpace(tag + fmt.Sprintf(` view:"%s"`, rootView)) + } + if !strings.Contains(tag, `sql:"`) { + tag = strings.TrimSpace(tag + fmt.Sprintf(` sql:"uri=%s/%s.sql"`, embedURI, rootView)) + } + if rootViewMetadata != nil { + tag = mergeViewSQLTag(tag, rootViewMetadata) + } + if !strings.Contains(tag, `anonymous:"`) && p.Tag != "" && strings.Contains(p.Tag, "anonymous") { + tag += ` anonymous:"true"` + } + builder.WriteString(fmt.Sprintf("\t%s %s%s `%s`\n", fieldName, typePrefix, viewTypeName, tag)) + case "status": + builder.WriteString(fmt.Sprintf("\tresponse.Status `parameter:\",kind=output,in=status\" json:\",omitempty\"`\n")) + default: + // Other output parameters (meta, etc.) + typeName := "interface{}" + if p.Schema != nil && p.Schema.Name != "" { + typeName = p.Schema.Name + } + builder.WriteString(fmt.Sprintf("\t%s %s `parameter:\",kind=output,in=%s\"`\n", p.Name, typeName, p.In.Name)) + } + } + if mutableSupport != nil && !hasViolations { + builder.WriteString("\tViolations validator.Violations `json:\",omitempty\"`\n") + } + + builder.WriteString("}\n\n") + if mutableSupport != nil { + builder.WriteString(fmt.Sprintf("func (o *%s) setError(err error) {\n", outputTypeName)) + builder.WriteString("\to.Status.Message = err.Error()\n") + builder.WriteString("\to.Status.Status = \"error\"\n") + builder.WriteString("}\n\n") + } +} + +// resolveOutputWildcardTypes resolves output parameters with wildcard type `?` or empty +// schema to the view entity type. The legacy translator does this in updateParameterWithComponentOutputType. +func (g *ComponentCodegen) resolveOutputWildcardTypes(params state.Parameters, componentName string) { + viewType := componentName + "View" + rootView := g.rootResourceView() + for _, p := range params { + if p == nil || p.In == nil { + continue + } + if p.In.Kind != state.KindOutput { + continue + } + if p.Schema == nil { + p.Schema = &state.Schema{} + } + // Only view outputs default to the root view shape. Summary/status outputs + // must keep their own materialized schema types. + if p.In.Name == "view" && (p.Schema.Name == "" || p.Schema.DataType == "" || p.Schema.DataType == "?") { + p.Schema.Name = viewType + p.Schema.DataType = "*" + viewType + if p.Schema.Cardinality == "" { + p.Schema.Cardinality = state.Many + } + } + // Add view tag if missing + if p.In.Name == "view" && !strings.Contains(p.Tag, "view:") { + rootViewName := g.Component.RootView + p.Tag += fmt.Sprintf(` view:"%s"`, rootViewName) + } + if p.In.Name == "view" && rootView != nil { + p.Tag = mergeViewSQLTag(p.Tag, rootView) + } + } +} + +// defaultOutputParameters creates the default output parameters for a reader component: +// - Data: the main view data (anonymous, kind=output, in=view) +// - Status: response status (anonymous, kind=output, in=status) +// This mirrors internal/translator output.go ensureOutputParameters. +func (g *ComponentCodegen) defaultOutputParameters(componentName string) state.Parameters { + if !g.shouldDefaultReaderOutput() { + return nil + } + rootView := g.Component.RootView + viewType := componentName + "View" + + // Data parameter — references the root view + dataParam := &state.Parameter{ + Name: "Data", + In: state.NewOutputLocation("view"), + Tag: fmt.Sprintf(`anonymous:"true" view:"%s"`, rootView), + Schema: &state.Schema{ + Name: viewType, + DataType: "*" + viewType, + Cardinality: state.Many, + }, + } + + // Status parameter — response.Status + statusParam := &state.Parameter{ + Name: "Status", + In: state.NewOutputLocation("status"), + Tag: `anonymous:"true" json:",omitempty"`, + Schema: &state.Schema{DataType: "response.Status"}, + } + + return state.Parameters{dataParam, statusParam} +} + +func (g *ComponentCodegen) shouldDefaultReaderOutput() bool { + if g == nil || g.Resource == nil || g.Component == nil { + return true + } + rootViewName := strings.TrimSpace(g.Component.RootView) + if rootViewName == "" { + return true + } + rootView, _ := g.Resource.View(rootViewName) + if rootView == nil { + return true + } + switch rootView.Mode { + case view.ModeExec, view.ModeHandler: + return false + default: + return true + } +} + +func (g *ComponentCodegen) withRegister() bool { + if g.WithRegister == nil { + return true // default enabled + } + return *g.WithRegister +} + +func (g *ComponentCodegen) typeSpec(key string) *shapeload.TypeSpec { + if g == nil || g.Component == nil || g.Component.TypeSpecs == nil { + return nil + } + return g.Component.TypeSpecs[key] +} + +func (g *ComponentCodegen) inputTypeName(componentName string) string { + if spec := g.typeSpec("input"); spec != nil && strings.TrimSpace(spec.TypeName) != "" { + return strings.TrimSpace(spec.TypeName) + } + return componentName + "Input" +} + +func (g *ComponentCodegen) outputTypeName(componentName string) string { + if spec := g.typeSpec("output"); spec != nil && strings.TrimSpace(spec.TypeName) != "" { + return strings.TrimSpace(spec.TypeName) + } + return componentName + "Output" +} + +func (g *ComponentCodegen) rootViewTypeName(componentName string) string { + rootView := strings.TrimSpace(componentName) + if g.Component != nil && strings.TrimSpace(g.Component.RootView) != "" { + rootView = strings.TrimSpace(g.Component.RootView) + } + if spec := g.typeSpec("view:" + rootView); spec != nil && strings.TrimSpace(spec.TypeName) != "" { + return strings.TrimSpace(spec.TypeName) + } + return componentName + "View" +} + +func (g *ComponentCodegen) rootViewSourceURL() string { + if g == nil || g.Resource == nil { + return "" + } + rootView := "" + if g.Component != nil { + rootView = strings.TrimSpace(g.Component.RootView) + } + if rootView != "" { + if aView, _ := g.Resource.View(rootView); aView != nil && aView.Template != nil { + return strings.TrimSpace(aView.Template.SourceURL) + } + } + if len(g.Resource.Views) == 0 || g.Resource.Views[0] == nil || g.Resource.Views[0].Template == nil { + return "" + } + return strings.TrimSpace(g.Resource.Views[0].Template.SourceURL) +} + +func (g *ComponentCodegen) rootResourceView() *view.View { + if g == nil || g.Resource == nil { + return nil + } + rootView := "" + if g.Component != nil { + rootView = strings.TrimSpace(g.Component.RootView) + } + if rootView != "" { + if aView, _ := g.Resource.View(rootView); aView != nil { + return aView + } + } + if len(g.Resource.Views) == 0 { + return nil + } + return g.Resource.Views[0] +} + +func (g *ComponentCodegen) rootSummarySourceURL() string { + if g == nil || g.Resource == nil { + return "" + } + rootView := "" + var candidate *view.View + if g.Component != nil { + rootView = strings.TrimSpace(g.Component.RootView) + } + if rootView != "" { + if aView, _ := g.Resource.View(rootView); aView != nil { + candidate = aView + } + } + if candidate == nil { + if len(g.Resource.Views) == 0 || g.Resource.Views[0] == nil { + return "" + } + candidate = g.Resource.Views[0] + } + if candidate.Template == nil || candidate.Template.Summary == nil { + return "" + } + if sourceURL := strings.TrimSpace(candidate.Template.Summary.SourceURL); sourceURL != "" { + return sourceURL + } + return path.Join(text.CaseFormatUpperCamel.Format(g.componentName(), text.CaseFormatLowerUnderscore), strings.ToLower(candidate.Name)+"_summary.sql") +} + +func (g *ComponentCodegen) resolveOutputDestFileName(defaultName string) string { + if spec := g.typeSpec("output"); spec != nil { + if dest := strings.TrimSpace(spec.Dest); dest != "" { + return dest + } + } + if g.Component != nil && g.Component.Directives != nil { + if dest := strings.TrimSpace(g.Component.Directives.OutputDest); dest != "" { + return dest + } + } + return defaultName +} + +func (g *ComponentCodegen) resolveInputDestFileName(defaultName string) string { + if spec := g.typeSpec("input"); spec != nil { + if dest := strings.TrimSpace(spec.Dest); dest != "" { + return dest + } + } + if g.Component != nil && g.Component.Directives != nil { + if dest := strings.TrimSpace(g.Component.Directives.InputDest); dest != "" { + return dest + } + } + return defaultName +} + +func (g *ComponentCodegen) resolveViewDestFileName(defaultName string) string { + if root := strings.TrimSpace(g.Component.RootView); root != "" { + if spec := g.typeSpec("view:" + root); spec != nil { + if dest := strings.TrimSpace(spec.Dest); dest != "" { + return dest + } + } + } + if g.Component != nil && g.Component.TypeSpecs != nil { + for _, spec := range g.Component.TypeSpecs { + if spec == nil || spec.Role != shapeload.TypeRoleView { + continue + } + if dest := strings.TrimSpace(spec.Dest); dest != "" { + return dest + } + } + } + if g.Component != nil && g.Component.Directives != nil { + if dest := strings.TrimSpace(g.Component.Directives.Dest); dest != "" { + return dest + } + } + return defaultName +} + +func (g *ComponentCodegen) resolveRouterDestFileName(defaultName string) string { + if g.Component != nil && g.Component.Directives != nil { + if dest := strings.TrimSpace(g.Component.Directives.RouterDest); dest != "" { + return dest + } + } + return defaultName +} + +func (g *ComponentCodegen) writeSectionFile(dest, packageName string, imports []string, sections ...string) error { + var builder strings.Builder + builder.WriteString("package " + packageName + "\n\n") + if len(imports) > 0 { + builder.WriteString("import (\n") + for _, imp := range imports { + imp = strings.TrimSpace(imp) + if imp == "" { + continue + } + if strings.Contains(imp, " ") { + builder.WriteString("\t" + imp + "\n") + } else { + builder.WriteString("\t\"" + imp + "\"\n") + } + } + builder.WriteString(")\n\n") + } + builder.WriteString("// Code generated by datly transcribe. DO NOT EDIT.\n\n") + for _, section := range sections { + if strings.TrimSpace(section) == "" { + continue + } + builder.WriteString(section) + if !strings.HasSuffix(section, "\n\n") { + builder.WriteString("\n") + } + } + return writeAtomic(dest, []byte(dedupeGeneratedStructFields(builder.String())), 0o644) +} + +func (g *ComponentCodegen) appendSectionFile(dest string, sections ...string) error { + data, err := os.ReadFile(dest) + if err != nil { + return err + } + var builder strings.Builder + builder.Write(data) + if len(data) > 0 && !strings.HasSuffix(string(data), "\n") { + builder.WriteString("\n") + } + for _, section := range sections { + if strings.TrimSpace(section) == "" { + continue + } + builder.WriteString("\n") + builder.WriteString(section) + if !strings.HasSuffix(section, "\n") { + builder.WriteString("\n") + } + } + return writeAtomic(dest, []byte(dedupeGeneratedStructFields(builder.String())), 0o644) +} + +func dedupeGeneratedStructFields(source string) string { + lines := strings.Split(source, "\n") + var result []string + inStruct := false + fieldNames := map[string]bool{} + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "type ") && strings.HasSuffix(trimmed, "struct {") { + inStruct = true + fieldNames = map[string]bool{} + result = append(result, line) + continue + } + if inStruct { + if trimmed == "}" { + inStruct = false + fieldNames = nil + result = append(result, line) + continue + } + if name := generatedFieldName(line); name != "" { + if fieldNames[name] { + continue + } + fieldNames[name] = true + } + } + result = append(result, line) + } + return strings.Join(result, "\n") +} + +func generatedFieldName(line string) string { + trimmed := strings.TrimLeft(line, "\t ") + if trimmed == "" { + return "" + } + r, size := utf8.DecodeRuneInString(trimmed) + if r == utf8.RuneError || !unicode.IsUpper(r) { + return "" + } + rest := trimmed[size:] + var b strings.Builder + b.WriteRune(r) + for _, rr := range rest { + if rr == ' ' || rr == '\t' { + break + } + if unicode.IsLetter(rr) || unicode.IsDigit(rr) || rr == '_' { + b.WriteRune(rr) + continue + } + return "" + } + return b.String() +} + +func (g *ComponentCodegen) componentName() string { + name := "" + if g.Component != nil { + name = g.Component.RootView + } + if name == "" && g.Resource != nil && len(g.Resource.Views) > 0 { + name = g.Resource.Views[0].Name + } + if name == "" { + name = "Component" + } + return state.SanitizeTypeName(name) +} + +type namedHelperType struct { + TypeName string + Decl string + Imports []string +} + +func collectNamedHelperTypes(rType reflect.Type, currentPackage string, skip map[string]bool) []namedHelperType { + if rType == nil { + return nil + } + skip = cloneTypeNameSet(skip) + seen := map[string]bool{} + importSet := map[string]bool{} + var result []namedHelperType + var visitType func(reflect.Type) + var visitField func(reflect.StructField) + + visitField = func(field reflect.StructField) { + typeName := strings.TrimSpace(field.Tag.Get("typeName")) + baseType := unwrapAnonymousStructType(field.Type) + if typeName != "" && baseType != nil && baseType.Name() == "" && !skip[typeName] && !seen[typeName] { + seen[typeName] = true + skip[typeName] = true + imports := map[string]bool{} + for _, imp := range collectTypeImports(baseType, currentPackage) { + imports[imp] = true + importSet[imp] = true + } + result = append(result, namedHelperType{ + TypeName: typeName, + Decl: fmt.Sprintf("type %s struct {\n%s}\n\n", typeName, structFieldsSource(baseType)), + Imports: sortedImportSet(imports), + }) + } + visitType(field.Type) + } + + visitType = func(t reflect.Type) { + if t == nil { + return + } + for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice || t.Kind() == reflect.Array { + t = t.Elem() + if t == nil { + return + } + } + if t.Kind() == reflect.Map { + visitType(t.Key()) + visitType(t.Elem()) + return + } + if t.Kind() != reflect.Struct { + return + } + for i := 0; i < t.NumField(); i++ { + visitField(t.Field(i)) + } + } + + visitType(rType) + return result +} + +func helperImports(items []namedHelperType) []string { + imports := map[string]bool{} + for _, item := range items { + for _, imp := range item.Imports { + if strings.TrimSpace(imp) != "" { + imports[imp] = true + } + } + } + return sortedImportSet(imports) +} + +func cloneTypeNameSet(src map[string]bool) map[string]bool { + if len(src) == 0 { + return map[string]bool{} + } + ret := make(map[string]bool, len(src)) + for key, value := range src { + ret[key] = value + } + return ret +} + +func sortedImportSet(src map[string]bool) []string { + if len(src) == 0 { + return nil + } + ret := make([]string, 0, len(src)) + for key := range src { + ret = append(ret, key) + } + sort.Strings(ret) + return ret +} + +func unwrapAnonymousStructType(rType reflect.Type) reflect.Type { + if rType == nil { + return nil + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + if rType == nil { + return nil + } + } + if rType.Kind() != reflect.Struct { + return nil + } + return rType +} + +func (g *ComponentCodegen) outputUsesResponse(outputParams state.Parameters) bool { + for _, p := range outputParams { + if p == nil || p.In == nil { + continue + } + if p.In.Name == "status" { + return true + } + } + return len(outputParams) > 0 || g.shouldDefaultReaderOutput() +} + +func (g *ComponentCodegen) buildImports(includeRouter bool, includeResponse bool) []string { + var imports []string + needsReflect := g.withRegister() || g.WithContract + if needsReflect { + imports = append(imports, + "reflect", + ) + } + if g.withRegister() { + imports = append(imports, "github.com/viant/xdatly/types/core") + checksumPkg := "github.com/viant/xdatly/types/custom/checksum" + if g.PackagePath != "" { + if idx := strings.LastIndex(g.PackagePath, "/pkg/"); idx != -1 { + checksumPkg = g.PackagePath[:idx] + "/pkg/checksum" + } + } + checksumParent, _ := path.Split(checksumPkg) + if !strings.HasSuffix(strings.TrimSuffix(checksumParent, "/"), "dependency") { + checksumPkg = path.Join(checksumParent, "dependency", "checksum") + } + imports = append(imports, checksumPkg) + } + if includeResponse { + imports = append(imports, "github.com/viant/xdatly/handler/response") + } + if g.WithEmbed { + imports = append(imports, "embed") + } + if g.WithContract { + imports = append(imports, + "fmt", + "context", + "github.com/viant/datly/view", + "github.com/viant/datly/repository", + "github.com/viant/datly/repository/contract", + "github.com/viant/datly", + ) + } + if includeRouter { + imports = append(imports, "github.com/viant/xdatly") + } + return imports +} + +func (g *ComponentCodegen) buildRouterImports() []string { + return []string{"github.com/viant/xdatly"} +} + +func (g *ComponentCodegen) renderComponentHolder(builder *strings.Builder, componentName, inputTypeName, outputTypeName string, selectorHolders []codegenSelectorHolder) { + method := strings.TrimSpace(g.Component.Method) + if method == "" { + method = "GET" + } + uri := strings.TrimSpace(g.Component.URI) + if uri == "" { + uri = "/" + } + tag := fmt.Sprintf(`component:",path=%s,method=%s`, uri, method) + if connectorRef := strings.TrimSpace(g.rootConnectorRef()); connectorRef != "" { + tag += fmt.Sprintf(`,connector=%s`, connectorRef) + } + if marshaller := strings.TrimSpace(g.rootMarshaller()); marshaller != "" { + tag += fmt.Sprintf(`,marshaller=%s`, marshaller) + } + if handlerRef := strings.TrimSpace(g.rootHandlerRef()); handlerRef != "" { + tag += fmt.Sprintf(`,handler=%s`, handlerRef) + } + if viewTypeName := strings.TrimSpace(g.rootViewTypeName(componentName)); viewTypeName != "" { + tag += fmt.Sprintf(`,view=%s`, viewTypeName) + } + if sourceURL := strings.TrimSpace(g.rootViewSourceURL()); sourceURL != "" { + tag += fmt.Sprintf(`,source=%s`, sourceURL) + } + if summaryURL := strings.TrimSpace(g.rootSummarySourceURL()); summaryURL != "" { + tag += fmt.Sprintf(`,summary=%s`, summaryURL) + } + if reportTag := g.reportComponentTag(); reportTag != "" { + tag += reportTag + } + tag += `"` + builder.WriteString(fmt.Sprintf("type %sRouter struct {\n", componentName)) + builder.WriteString(fmt.Sprintf("\t%s xdatly.Component[%s, %s] `%s`\n", componentName, inputTypeName, outputTypeName, tag)) + for _, holder := range selectorHolders { + if holder.Type == nil || strings.TrimSpace(holder.QuerySelector) == "" || strings.TrimSpace(holder.FieldName) == "" { + continue + } + builder.WriteString(fmt.Sprintf("\t%s struct {\n", holder.FieldName)) + builder.WriteString(indentSource(structFieldsSource(holder.Type), "\t\t")) + builder.WriteString(fmt.Sprintf("\t} `querySelector:%q`\n", holder.QuerySelector)) + } + builder.WriteString("}\n\n") +} + +func (g *ComponentCodegen) renderDefineComponent(builder *strings.Builder, componentName, inputTypeName, outputTypeName string) { + method := strings.TrimSpace(g.Component.Method) + if method == "" { + method = "GET" + } + uri := strings.TrimSpace(g.Component.URI) + if uri == "" { + uri = "/" + } + connectorRef := strings.TrimSpace(g.rootConnectorRef()) + pathVar := componentName + "PathURI" + builder.WriteString(fmt.Sprintf("var %s = %q\n\n", pathVar, uri)) + builder.WriteString(fmt.Sprintf("func Define%sComponent(ctx context.Context, srv *datly.Service) error {\n", componentName)) + builder.WriteString("\taComponent, err := repository.NewComponent(\n") + builder.WriteString(fmt.Sprintf("\t\tcontract.NewPath(%q, %s),\n", method, pathVar)) + builder.WriteString("\t\trepository.WithResource(srv.Resource()),\n") + builder.WriteString("\t\trepository.WithContract(\n") + if g.WithEmbed { + builder.WriteString(fmt.Sprintf("\t\t\treflect.TypeOf(%s{}),\n", inputTypeName)) + builder.WriteString(fmt.Sprintf("\t\t\treflect.TypeOf(%s{}), &%sFS", outputTypeName, componentName)) + } else { + builder.WriteString(fmt.Sprintf("\t\t\treflect.TypeOf(%s{}),\n", inputTypeName)) + builder.WriteString(fmt.Sprintf("\t\t\treflect.TypeOf(%s{}), nil", outputTypeName)) + } + if connectorRef != "" { + builder.WriteString(fmt.Sprintf(`, view.WithConnectorRef(%q)`, connectorRef)) + } + builder.WriteString(")") + if reportOption := g.reportComponentOption(); reportOption != "" { + builder.WriteString(",\n") + builder.WriteString("\t\t") + builder.WriteString(reportOption) + } + builder.WriteString(")\n\n") + builder.WriteString("\tif err != nil {\n") + builder.WriteString(fmt.Sprintf("\t\treturn fmt.Errorf(\"failed to create %s component: %%w\", err)\n", componentName)) + builder.WriteString("\t}\n") + builder.WriteString("\tif err := srv.AddComponent(ctx, aComponent); err != nil {\n") + builder.WriteString(fmt.Sprintf("\t\treturn fmt.Errorf(\"failed to add %s component: %%w\", err)\n", componentName)) + builder.WriteString("\t}\n") + builder.WriteString("\treturn nil\n") + builder.WriteString("}\n\n") +} + +func (g *ComponentCodegen) reportComponentTag() string { + if g.Component == nil || g.Component.Report == nil || !g.Component.Report.Enabled { + return "" + } + report := g.Component.Report + tag := ",report=true" + if value := strings.TrimSpace(report.Input); value != "" { + tag += fmt.Sprintf(",reportInput=%s", value) + } + if value := strings.TrimSpace(report.Dimensions); value != "" { + tag += fmt.Sprintf(",reportDimensions=%s", value) + } + if value := strings.TrimSpace(report.Measures); value != "" { + tag += fmt.Sprintf(",reportMeasures=%s", value) + } + if value := strings.TrimSpace(report.Filters); value != "" { + tag += fmt.Sprintf(",reportFilters=%s", value) + } + if value := strings.TrimSpace(report.OrderBy); value != "" { + tag += fmt.Sprintf(",reportOrderBy=%s", value) + } + if value := strings.TrimSpace(report.Limit); value != "" { + tag += fmt.Sprintf(",reportLimit=%s", value) + } + if value := strings.TrimSpace(report.Offset); value != "" { + tag += fmt.Sprintf(",reportOffset=%s", value) + } + return tag +} + +func (g *ComponentCodegen) reportComponentOption() string { + if g.Component == nil || g.Component.Report == nil || !g.Component.Report.Enabled { + return "" + } + report := g.Component.Report + parts := []string{"Enabled: true"} + if value := strings.TrimSpace(report.Input); value != "" { + parts = append(parts, fmt.Sprintf("Input: %q", value)) + } + if value := strings.TrimSpace(report.Dimensions); value != "" { + parts = append(parts, fmt.Sprintf("Dimensions: %q", value)) + } + if value := strings.TrimSpace(report.Measures); value != "" { + parts = append(parts, fmt.Sprintf("Measures: %q", value)) + } + if value := strings.TrimSpace(report.Filters); value != "" { + parts = append(parts, fmt.Sprintf("Filters: %q", value)) + } + if value := strings.TrimSpace(report.OrderBy); value != "" { + parts = append(parts, fmt.Sprintf("OrderBy: %q", value)) + } + if value := strings.TrimSpace(report.Limit); value != "" { + parts = append(parts, fmt.Sprintf("Limit: %q", value)) + } + if value := strings.TrimSpace(report.Offset); value != "" { + parts = append(parts, fmt.Sprintf("Offset: %q", value)) + } + return fmt.Sprintf("repository.WithReport(&repository.Report{%s})", strings.Join(parts, ", ")) +} + +func (g *ComponentCodegen) rootConnectorRef() string { + if g.Resource == nil { + return "" + } + root := strings.TrimSpace(g.Component.RootView) + for _, aView := range g.Resource.Views { + if aView == nil || aView.Connector == nil { + continue + } + if root != "" && aView.Name == root { + if aView.Connector.Ref != "" { + return aView.Connector.Ref + } + return aView.Connector.Name + } + } + for _, aView := range g.Resource.Views { + if aView != nil && aView.Connector != nil { + if aView.Connector.Ref != "" { + return aView.Connector.Ref + } + return aView.Connector.Name + } + } + return "" +} + +func (g *ComponentCodegen) rootMarshaller() string { + if g == nil || g.Component == nil { + return "" + } + for _, route := range g.Component.ComponentRoutes { + if route == nil { + continue + } + if marshaller := strings.TrimSpace(route.Marshaller); marshaller != "" { + return marshaller + } + } + return "" +} + +func (g *ComponentCodegen) rootHandlerRef() string { + if g == nil || g.Component == nil { + return "" + } + for _, route := range g.Component.ComponentRoutes { + if route == nil { + continue + } + if handler := strings.TrimSpace(route.Handler); handler != "" { + return handler + } + } + return "" +} + +func structFieldsSource(rType reflect.Type) string { + if rType == nil { + return "" + } + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return "" + } + var b strings.Builder + for i := 0; i < rType.NumField(); i++ { + f := rType.Field(i) + if !f.IsExported() { + continue + } + typeExpr := sourceFieldTypeExpr(f) + b.WriteString("\t" + f.Name + " " + typeExpr) + if f.Tag != "" { + b.WriteString(" `" + string(f.Tag) + "`") + } + b.WriteString("\n") + } + return b.String() +} + +func indentSource(source, prefix string) string { + source = strings.TrimRight(source, "\n") + if source == "" { + return "" + } + lines := strings.Split(source, "\n") + for i, line := range lines { + lines[i] = prefix + line + } + return strings.Join(lines, "\n") + "\n" +} + +func sourceFieldTypeExpr(field reflect.StructField) string { + typeName := strings.TrimSpace(field.Tag.Get("typeName")) + if typeName == "" { + return field.Type.String() + } + return rewriteFieldTypeExpr(field.Type, typeName) +} + +func rewriteFieldTypeExpr(rType reflect.Type, typeName string) string { + if rType == nil { + return typeName + } + switch rType.Kind() { + case reflect.Ptr: + return "*" + rewriteFieldTypeExpr(rType.Elem(), typeName) + case reflect.Slice: + return "[]" + rewriteFieldTypeExpr(rType.Elem(), typeName) + case reflect.Array: + return fmt.Sprintf("[%d]%s", rType.Len(), rewriteFieldTypeExpr(rType.Elem(), typeName)) + case reflect.Map: + return "map[" + rType.Key().String() + "]" + rewriteFieldTypeExpr(rType.Elem(), typeName) + case reflect.Struct: + if rType.Name() == "" { + return typeName + } + } + return rType.String() +} + +func resourceToCodegenDoc(resource *view.Resource, typeCtx *typectx.Context) *shape.Document { + root := map[string]any{} + var views []any + for _, v := range resource.Views { + if v == nil { + continue + } + viewMap := map[string]any{ + "Name": v.Name, + "Table": v.Table, + "Mode": string(v.Mode), + } + if v.Schema != nil { + schema := map[string]any{} + if v.Schema.Name != "" { + schema["Name"] = v.Schema.Name + } + viewMap["Schema"] = schema + } + if len(v.Columns) > 0 { + var cols []any + for _, c := range v.Columns { + if c == nil { + continue + } + cols = append(cols, map[string]any{ + "Name": c.Name, + "DataType": c.DataType, + "Nullable": c.Nullable, + }) + } + viewMap["Columns"] = cols + } + views = append(views, viewMap) + } + root["Resource"] = map[string]any{"Views": views} + return &shape.Document{Root: root, TypeContext: typeCtx} +} diff --git a/repository/shape/xgen/codegen_contract_parity_test.go b/repository/shape/xgen/codegen_contract_parity_test.go new file mode 100644 index 000000000..7beec9585 --- /dev/null +++ b/repository/shape/xgen/codegen_contract_parity_test.go @@ -0,0 +1,134 @@ +package xgen + +import ( + "os" + "path/filepath" + "strings" + "testing" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + shapeload "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" +) + +func TestComponentCodegen_GeneratesDefineComponentParitySnippet(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "pkg", "dev", "wrapper") + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/dev/vendors/{vendorID}", + RootView: "Wrapper", + } + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "Wrapper", + Connector: &view.Connector{Connection: view.Connection{DBConfig: view.DBConfig{Name: "dev"}}}, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + }, + }, + } + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "wrapper", + PackagePath: "github.com/acme/project/pkg/dev/wrapper", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + expectContains(t, generated, `var WrapperPathURI = "/v1/api/dev/vendors/{vendorID}"`) + expectContains(t, generated, `type WrapperRouter struct {`) + expectContains(t, generated, `Wrapper xdatly.Component[WrapperInput, WrapperOutput] `+"`"+`component:",path=/v1/api/dev/vendors/{vendorID},method=GET,connector=dev,view=WrapperView"`+"`") + expectContains(t, generated, `func DefineWrapperComponent(ctx context.Context, srv *datly.Service) error {`) + expectContains(t, generated, `contract.NewPath("GET", WrapperPathURI)`) + expectContains(t, generated, `repository.WithResource(srv.Resource())`) + expectContains(t, generated, `repository.WithContract(`) + expectContains(t, generated, `reflect.TypeOf(WrapperInput{})`) + expectContains(t, generated, `reflect.TypeOf(WrapperOutput{}), &WrapperFS, view.WithConnectorRef("dev"))`) +} + +func expectContains(t *testing.T, actual string, fragment string) { + t.Helper() + if !strings.Contains(actual, fragment) { + t.Fatalf("expected generated source to contain %q\nsource:\n%s", fragment, actual) + } +} + +func TestComponentCodegen_GeneratesSeparateRouterFile(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "pkg", "dev", "vendor") + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/dev/vendors", + RootView: "Vendor", + Directives: &dqlshape.Directives{ + RouterDest: "vendor_router.go", + }, + } + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "Vendor", + Connector: &view.Connector{Connection: view.Connection{DBConfig: view.DBConfig{Name: "dev"}}}, + Columns: []*view.Column{{Name: "ID", DataType: "int"}}, + }, + }, + } + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "vendor", + PackagePath: "github.com/acme/project/pkg/dev/vendor", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + if filepath.Base(result.RouterFilePath) != "vendor_router.go" { + t.Fatalf("expected router file vendor_router.go, got %s", result.RouterFilePath) + } + if len(result.GeneratedFiles) != 2 { + t.Fatalf("expected 2 generated files, got %v", result.GeneratedFiles) + } + routerData, err := os.ReadFile(result.RouterFilePath) + if err != nil { + t.Fatalf("read router file: %v", err) + } + routerSource := string(routerData) + expectContains(t, routerSource, `type VendorRouter struct {`) + expectContains(t, routerSource, `Vendor xdatly.Component[VendorInput, VendorOutput] `+"`"+`component:",path=/v1/api/dev/vendors,method=GET,connector=dev,view=VendorView"`+"`") + outputData, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read primary file: %v", err) + } + if strings.Contains(string(outputData), "type VendorRouter struct") { + t.Fatalf("expected router declaration to be split out of primary output file:\n%s", string(outputData)) + } +} diff --git a/repository/shape/xgen/codegen_groupable_test.go b/repository/shape/xgen/codegen_groupable_test.go new file mode 100644 index 000000000..083dfda9c --- /dev/null +++ b/repository/shape/xgen/codegen_groupable_test.go @@ -0,0 +1,351 @@ +package xgen + +import ( + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type codegenMetaView struct { + PageCnt *int +} + +type codegenProductsMetaView struct { + VendorId *int +} + +type staleCodegenProductsMetaView struct { + VendorId string +} + +type codegenProductsView struct { + VendorId *int +} + +type codegenVendorView struct { + ID int + Products []*codegenProductsView + ProductsMeta *codegenProductsMetaView +} + +func TestComponentCodegen_ColumnFieldTag_EmitsGroupableTag(t *testing.T) { + groupable := true + codegen := &ComponentCodegen{} + aView := &view.View{ + ColumnsConfig: map[string]*view.ColumnConfig{ + "REGION": {Name: "REGION", Groupable: &groupable}, + }, + } + column := &view.Column{Name: "REGION", DataType: "string"} + + tag := codegen.columnFieldTag(aView, column) + assert.Contains(t, tag, `groupable:"true"`) + assert.Contains(t, tag, `sqlx:"REGION"`) +} + +func TestComponentCodegen_GeneratesSelectorHolderOutsideBusinessInput(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "pkg", "dev", "vendor") + component := &load.Component{ + Name: "Vendor", + Method: "GET", + URI: "/v1/api/dev/vendors-grouping", + RootView: "Vendor", + Report: &dqlshape.ReportDirective{ + Enabled: true, + Input: "VendorReportInput", + Dimensions: "Dims", + Measures: "Metrics", + Filters: "Predicates", + OrderBy: "Sort", + Limit: "Take", + Offset: "Skip", + }, + Directives: &dqlshape.Directives{ + InputDest: "vendor_input.go", + OutputDest: "vendor_output.go", + RouterDest: "vendor_router.go", + }, + Input: []*plan.State{ + {Parameter: state.Parameter{Name: "VendorIDs", In: state.NewQueryLocation("vendorIDs"), Schema: state.NewSchema(reflect.TypeOf([]int{}))}}, + {Parameter: state.Parameter{Name: "Fields", In: state.NewQueryLocation("_fields"), Schema: state.NewSchema(reflect.TypeOf([]string{}))}, QuerySelector: "vendor"}, + {Parameter: state.Parameter{Name: "OrderBy", In: state.NewQueryLocation("_orderby"), Schema: state.NewSchema(reflect.TypeOf(""))}, QuerySelector: "vendor"}, + }, + } + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "Vendor", + Groupable: true, + Selector: &view.Config{ + Constraints: &view.Constraints{ + OrderBy: true, + OrderByColumn: map[string]string{"accountId": "ACCOUNT_ID"}, + }, + }, + Template: &view.Template{SourceURL: "vendor/vendor.sql"}, + Columns: []*view.Column{{Name: "ACCOUNT_ID", DataType: "int"}}, + }, + }, + } + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "vendor", + PackagePath: "github.com/acme/project/pkg/dev/vendor", + } + + result, err := (&ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithContract: true, + }).Generate() + require.NoError(t, err) + + inputSource, err := os.ReadFile(result.InputFilePath) + require.NoError(t, err) + assert.Contains(t, string(inputSource), "type VendorInput struct {") + assert.Contains(t, string(inputSource), "VendorIDs []int") + assert.NotContains(t, string(inputSource), "Fields []string") + assert.NotContains(t, string(inputSource), "OrderBy string") + + routerSource, err := os.ReadFile(result.RouterFilePath) + require.NoError(t, err) + assert.Contains(t, string(routerSource), "ViewSelect struct {") + assert.Contains(t, string(routerSource), `querySelector:"vendor"`) + assert.Contains(t, string(routerSource), `report=true`) + assert.Contains(t, string(routerSource), `reportInput=VendorReportInput`) + assert.Contains(t, string(routerSource), `reportDimensions=Dims`) + assert.Contains(t, string(routerSource), `Fields []string `+"`"+`parameter:"`) + assert.Contains(t, string(routerSource), `in=_fields`) + assert.Contains(t, string(routerSource), `OrderBy string `+"`"+`parameter:"`) + assert.Contains(t, string(routerSource), `in=_orderby`) + + outputSource, err := os.ReadFile(result.OutputFilePath) + require.NoError(t, err) + assert.Contains(t, string(outputSource), `repository.WithReport(&repository.Report{Enabled: true, Input: "VendorReportInput", Dimensions: "Dims", Measures: "Metrics", Filters: "Predicates", OrderBy: "Sort", Limit: "Take", Offset: "Skip"})`) + assert.Contains(t, string(outputSource), `view:"Vendor,groupable=true`) + assert.Contains(t, string(outputSource), `selectorOrderBy=true`) + assert.Contains(t, string(outputSource), `selectorOrderByColumns={accountId:ACCOUNT_ID}`) +} + +func TestComponentCodegen_GeneratesSummaryMetadata(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "pkg", "dev", "vendor") + vendorSchema := state.NewSchema(reflect.TypeOf(codegenVendorView{})) + vendorSchema.Name = "VendorView" + vendorSchema.DataType = "*VendorView" + productsSchema := state.NewSchema(reflect.TypeOf(codegenProductsView{})) + productsSchema.Name = "ProductsView" + productsSchema.DataType = "*ProductsView" + metaSchema := state.NewSchema(reflect.TypeOf(codegenMetaView{})) + metaSchema.Name = "MetaView" + metaSchema.DataType = "*MetaView" + productsMetaSchema := state.NewSchema(reflect.TypeOf(codegenProductsMetaView{})) + productsMetaSchema.Name = "ProductsMetaView" + productsMetaSchema.DataType = "*ProductsMetaView" + component := &load.Component{ + Name: "Vendor", + Method: "GET", + URI: "/v1/api/dev/meta/vendors-format", + RootView: "vendor", + Directives: &dqlshape.Directives{ + OutputDest: "vendor.go", + }, + Output: []*plan.State{ + { + Parameter: state.Parameter{ + Name: "Meta", + In: state.NewOutputLocation("summary"), + Schema: metaSchema, + }, + }, + { + Parameter: state.Parameter{ + Name: "Data", + In: state.NewOutputLocation("view"), + Schema: &state.Schema{ + Name: "VendorView", + DataType: "*VendorView", + Cardinality: state.Many, + }, + }, + }, + }, + } + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "vendor", + Template: &view.Template{ + SourceURL: "vendor/vendor.sql", + Summary: &view.TemplateSummary{ + Name: "Meta", + SourceURL: "vendor/vendor_summary.sql", + Schema: metaSchema, + }, + }, + Schema: vendorSchema, + With: []*view.Relation{ + { + Holder: "Products", + Cardinality: state.Many, + Of: &view.ReferenceView{ + View: view.View{ + Name: "products", + Template: &view.Template{ + SourceURL: "vendor/products.sql", + Summary: &view.TemplateSummary{ + Name: "ProductsMeta", + SourceURL: "vendor/products_summary.sql", + Schema: productsMetaSchema, + }, + }, + Schema: productsSchema, + }, + On: []*view.Link{{Field: "VendorId", Column: "VENDOR_ID"}}, + }, + On: []*view.Link{{Field: "Id", Column: "ID"}}, + }, + }, + Columns: []*view.Column{{Name: "ID", DataType: "int"}}, + }, + { + Name: "products", + Template: &view.Template{ + SourceURL: "vendor/products.sql", + Summary: &view.TemplateSummary{ + Name: "ProductsMeta", + SourceURL: "vendor/products_summary.sql", + Schema: productsMetaSchema, + }, + }, + Schema: productsSchema, + Columns: []*view.Column{{Name: "VENDOR_ID", DataType: "int"}}, + }, + }, + } + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "vendor", + PackagePath: "github.com/acme/project/pkg/dev/vendor", + } + + result, err := (&ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithContract: true, + }).Generate() + require.NoError(t, err) + + outputSource, err := os.ReadFile(result.OutputFilePath) + require.NoError(t, err) + source := string(outputSource) + assert.Contains(t, source, `Meta MetaView`) + assert.Contains(t, source, `parameter:",kind=output,in=summary"`) + assert.Contains(t, source, `view:"vendor,summaryURI=vendor/vendor_summary.sql"`) + assert.Contains(t, source, `type MetaView struct {`) + assert.Contains(t, source, `ProductsMeta *ProductsMetaView`) + assert.Contains(t, source, `view:",summaryURI=vendor/products_summary.sql"`) + assert.Contains(t, source, `type ProductsMetaView struct {`) +} + +func TestComponentCodegen_PrefersStandaloneChildSummarySchemaOverStaleRelationCopy(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "pkg", "dev", "vendor") + vendorSchema := state.NewSchema(reflect.TypeOf(codegenVendorView{})) + vendorSchema.Name = "VendorView" + vendorSchema.DataType = "*VendorView" + productsSchema := state.NewSchema(reflect.TypeOf(codegenProductsView{})) + productsSchema.Name = "ProductsView" + productsSchema.DataType = "*ProductsView" + staleSummarySchema := state.NewSchema(reflect.TypeOf(staleCodegenProductsMetaView{})) + staleSummarySchema.Name = "ProductsMetaView" + staleSummarySchema.DataType = "*ProductsMetaView" + refinedSummarySchema := state.NewSchema(reflect.TypeOf(codegenProductsMetaView{})) + refinedSummarySchema.Name = "ProductsMetaView" + refinedSummarySchema.DataType = "*ProductsMetaView" + component := &load.Component{ + Name: "Vendor", + Method: "GET", + URI: "/v1/api/dev/meta/vendors-format", + RootView: "vendor", + } + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "vendor", + Template: &view.Template{SourceURL: "vendor/vendor.sql"}, + Schema: vendorSchema, + With: []*view.Relation{ + { + Holder: "Products", + Cardinality: state.Many, + Of: &view.ReferenceView{ + View: view.View{ + Name: "products", + Template: &view.Template{ + SourceURL: "vendor/products.sql", + Summary: &view.TemplateSummary{ + Name: "ProductsMeta", + SourceURL: "vendor/products_summary.sql", + Schema: staleSummarySchema, + }, + }, + Schema: productsSchema, + }, + On: []*view.Link{{Field: "VendorId", Column: "VENDOR_ID"}}, + }, + On: []*view.Link{{Field: "Id", Column: "ID"}}, + }, + }, + }, + { + Name: "products", + Template: &view.Template{ + SourceURL: "vendor/products.sql", + Summary: &view.TemplateSummary{ + Name: "ProductsMeta", + SourceURL: "vendor/products_summary.sql", + Schema: refinedSummarySchema, + }, + }, + Schema: productsSchema, + }, + }, + } + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "vendor", + PackagePath: "github.com/acme/project/pkg/dev/vendor", + } + + result, err := (&ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithContract: true, + }).Generate() + require.NoError(t, err) + + outputSource, err := os.ReadFile(result.OutputFilePath) + require.NoError(t, err) + source := string(outputSource) + assert.Contains(t, source, `type ProductsMetaView struct {`) + assert.Contains(t, source, `VendorId *int`) + assert.NotContains(t, source, `VendorId string`) +} diff --git a/repository/shape/xgen/codegen_imports_test.go b/repository/shape/xgen/codegen_imports_test.go new file mode 100644 index 000000000..648e0dc8e --- /dev/null +++ b/repository/shape/xgen/codegen_imports_test.go @@ -0,0 +1,20 @@ +package xgen + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestComponentCodegen_buildImports_ChecksumPathDefault(t *testing.T) { + g := &ComponentCodegen{} + imports := g.buildImports(false, false) + assert.Contains(t, imports, "github.com/viant/xdatly/types/custom/dependency/checksum") + assert.NotContains(t, imports, "github.com/viant/xdatly/types/custom/checksum") +} + +func TestComponentCodegen_buildImports_ChecksumPathFromPackagePath(t *testing.T) { + g := &ComponentCodegen{PackagePath: "github.com/acme/project/pkg/dev/vendor"} + imports := g.buildImports(false, false) + assert.Contains(t, imports, "github.com/acme/project/pkg/dependency/checksum") +} diff --git a/repository/shape/xgen/codegen_input_view_test.go b/repository/shape/xgen/codegen_input_view_test.go new file mode 100644 index 000000000..d492af469 --- /dev/null +++ b/repository/shape/xgen/codegen_input_view_test.go @@ -0,0 +1,870 @@ +package xgen + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + shapeload "github.com/viant/datly/repository/shape/load" + shapeplan "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/sqlx/types" +) + +func TestComponentCodegen_ViewInput_UsesResolvedViewType(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "update") + + component := &shapeload.Component{ + Method: "POST", + URI: "/v1/api/shape/dev/auth/products/", + RootView: "ProductUpdate", + Input: []*shapeplan.State{ + {Parameter: state.Parameter{Name: "Jwt", In: state.NewHeaderLocation("Authorization"), Schema: &state.Schema{DataType: "string"}}}, + {Parameter: state.Parameter{Name: "Ids", In: state.NewBodyLocation("Ids"), Schema: &state.Schema{DataType: "[]int"}}}, + {Parameter: state.Parameter{Name: "Records", In: state.NewViewLocation("Records"), Schema: nil}}, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "ProductUpdate", + Mode: view.ModeExec, + }, + &view.View{ + Name: "Records", + Schema: &state.Schema{ + Name: "RecordsView", + DataType: "*RecordsView", + Cardinality: state.Many, + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "STATUS", DataType: "int", Nullable: true}, + {Name: "IS_AUTH", DataType: "int", Nullable: true}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "update", + PackagePath: "github.com/acme/project/shape/dev/vendor/update", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if strings.Contains(generated, "Records []interface {}") { + t.Fatalf("expected typed Records input field, got interface slice:\n%s", generated) + } + if !strings.Contains(generated, "Records []") || !strings.Contains(generated, "RecordsView") { + t.Fatalf("expected Records view input field to reference RecordsView:\n%s", generated) + } + if !strings.Contains(generated, `Status *int `+"`"+`sqlx:"STATUS" velty:"names=STATUS|Status"`+"`") { + t.Fatalf("expected exec view input helper type to retain velty aliases:\n%s", generated) + } + if !strings.Contains(generated, `IsAuth *int `+"`"+`sqlx:"IS_AUTH" velty:"names=IS_AUTH|IsAuth"`+"`") { + t.Fatalf("expected exec view input helper type to retain SQL alias velty names:\n%s", generated) + } +} + +func TestComponentCodegen_InputSynthesizesRoutePathParams(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "team", "delete") + + component := &shapeload.Component{ + Method: "DELETE", + URI: "/v1/api/shape/dev/team/{teamID}", + RootView: "Team", + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "Team", + Mode: view.ModeExec, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "delete", + PackagePath: "github.com/acme/project/shape/dev/team/delete", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if !strings.Contains(generated, "TeamID string") { + t.Fatalf("expected implicit route path parameter in generated input:\n%s", generated) + } + if !strings.Contains(generated, `parameter:"teamID,kind=path,in=teamID"`) { + t.Fatalf("expected TeamID path parameter tag in generated input:\n%s", generated) + } + if !strings.Contains(generated, `velty:"names=TeamID|teamID"`) { + t.Fatalf("expected TeamID path parameter velty aliases in generated input:\n%s", generated) + } +} + +func TestComponentCodegen_ExportsLowercaseInputFieldNames(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "env") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/vendors-env/", + RootView: "Vendor", + Input: []*shapeplan.State{ + {Parameter: state.Parameter{Name: "vendorIDs", In: state.NewQueryLocation("vendorIDs"), Schema: &state.Schema{DataType: "[]int", Cardinality: state.Many}}}, + {Parameter: state.Parameter{Name: "Vendor", In: state.NewConstLocation("Vendor"), Value: "VENDOR", Tag: `internal:"true"`, Schema: &state.Schema{DataType: "string", Cardinality: state.One}}}, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "Vendor", + Mode: view.ModeQuery, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "env", + PackagePath: "github.com/acme/project/shape/dev/vendor/env", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if !strings.Contains(generated, "VendorIDs ") { + t.Fatalf("expected exported generated field for lowercase query input:\n%s", generated) + } + if !strings.Contains(generated, `parameter:"vendorIDs,kind=query,in=vendorIDs"`) { + t.Fatalf("expected original query parameter name to be preserved in tag:\n%s", generated) + } +} + +func TestComponentCodegen_ReadView_OmitsVeltyTags(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "list") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/vendors/", + RootView: "Vendor", + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "Vendor", + Mode: view.ModeQuery, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "list", + PackagePath: "github.com/acme/project/shape/dev/vendor/list", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if strings.Contains(generated, `velty:"names=ID|Id"`) || strings.Contains(generated, `velty:"names=NAME|Name"`) { + t.Fatalf("expected read view fields to omit velty tags:\n%s", generated) + } +} + +func TestComponentCodegen_ReadInput_OmitsVeltyTags(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "list") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/vendors/", + RootView: "Vendor", + Input: []*shapeplan.State{ + {Parameter: state.Parameter{Name: "VendorName", In: state.NewFormLocation("name"), Schema: &state.Schema{DataType: "string"}}}, + {Parameter: state.Parameter{Name: "Fields", In: state.NewQueryLocation("fields"), Schema: &state.Schema{DataType: "[]string", Cardinality: state.Many}}}, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "Vendor", + Mode: view.ModeQuery, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "list", + PackagePath: "github.com/acme/project/shape/dev/vendor/list", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if strings.Contains(generated, `velty:"names=`) { + t.Fatalf("expected read input fields to omit velty tags:\n%s", generated) + } +} + +func TestComponentCodegen_HandlerExec_OmitsVeltyTags(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "auth") + + component := &shapeload.Component{ + Method: "POST", + URI: "/v1/api/shape/dev/auth/vendor", + RootView: "Auth", + ComponentRoutes: []*shapeplan.ComponentRoute{{ + Name: "Auth", + RoutePath: "/v1/api/shape/dev/auth/vendor", + Method: "POST", + Handler: "github.com/acme/project/shape/dev/vendor/auth.Handler", + }}, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "Auth", + Mode: view.ModeHandler, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "auth", + PackagePath: "github.com/acme/project/shape/dev/vendor/auth", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if strings.Contains(generated, `velty:"names=ID|Id"`) || strings.Contains(generated, `velty:"names=NAME|Name"`) { + t.Fatalf("expected handler-generated fields to omit velty tags:\n%s", generated) + } + if !strings.Contains(generated, `handler=github.com/acme/project/shape/dev/vendor/auth.Handler`) { + t.Fatalf("expected generated router tag to include handler reference:\n%s", generated) + } +} + +func TestComponentCodegen_CodecBackedInput_UsesCodecResultType(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "user_acl") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/auth/user-acl", + RootView: "UserAcl", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Jwt", + In: state.NewHeaderLocation("Authorization"), + Schema: &state.Schema{DataType: "string", Cardinality: state.One}, + Output: &state.Codec{Name: "JwtClaim"}, + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "UserAcl", + Mode: view.ModeQuery, + Columns: []*view.Column{ + {Name: "UserID", DataType: "int"}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "user_acl", + PackagePath: "github.com/acme/project/shape/dev/vendor/user_acl", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if !strings.Contains(generated, `"github.com/viant/scy/auth/jwt"`) { + t.Fatalf("expected generated input to import jwt claims package:\n%s", generated) + } + if strings.Contains(generated, `"github.com/golang-jwt/jwt/v5"`) { + t.Fatalf("expected generated input to avoid nested jwt dependency import drift:\n%s", generated) + } + if !strings.Contains(generated, "Jwt *jwt.Claims") { + t.Fatalf("expected codec-backed Jwt field to use codec result type:\n%s", generated) + } + if !strings.Contains(generated, `dataType=string`) || + !strings.Contains(generated, `codec:"JwtClaim"`) { + t.Fatalf("expected Jwt field tag to preserve raw datatype and codec metadata:\n%s", generated) + } +} + +func TestComponentCodegen_ViewInput_OverridesStaleInlineSchemaType(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "team", "user_team") + + staleFieldType := reflect.StructOf([]reflect.StructField{ + {Name: "Id", Type: reflect.TypeOf(""), Tag: `json:"id,omitempty"`}, + {Name: "TeamMembers", Type: reflect.TypeOf(""), Tag: `json:"teamMembers,omitempty"`}, + {Name: "Name", Type: reflect.TypeOf(""), Tag: `json:"name,omitempty"`}, + }) + + component := &shapeload.Component{ + Method: "PUT", + URI: "/v1/api/shape/dev/teams", + RootView: "UserTeam", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "TeamStats", + In: state.NewViewLocation("TeamStats"), + Schema: state.NewSchema(reflect.SliceOf(staleFieldType)), + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "UserTeam", + Mode: view.ModeExec, + }, + &view.View{ + Name: "TeamStats", + Schema: &state.Schema{ + Name: "TeamStatsView", + DataType: "*TeamStatsView", + Cardinality: state.Many, + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "TEAM_MEMBERS", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "user_team", + PackagePath: "github.com/acme/project/shape/dev/team/user_team", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if strings.Contains(generated, "TeamStats []struct") { + t.Fatalf("expected TeamStats view input field to use resolved TeamStatsView, got anonymous struct:\n%s", generated) + } + if !strings.Contains(generated, "TeamStats []") || !strings.Contains(generated, "TeamStatsView") { + t.Fatalf("expected TeamStats view input field to reference TeamStatsView:\n%s", generated) + } +} + +func TestComponentCodegen_ViewInput_ResolvesSnakeCaseViewName(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "team", "user_team") + + staleFieldType := reflect.StructOf([]reflect.StructField{ + {Name: "Id", Type: reflect.TypeOf(""), Tag: `json:"id,omitempty"`}, + }) + + component := &shapeload.Component{ + Method: "PUT", + URI: "/v1/api/shape/dev/teams", + RootView: "UserTeam", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "TeamStats", + In: state.NewViewLocation("TeamStats"), + Schema: state.NewSchema(reflect.SliceOf(staleFieldType)), + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{Name: "UserTeam", Mode: view.ModeExec}, + &view.View{ + Name: "team_stats", + Schema: &state.Schema{ + Name: "TeamStatsView", + DataType: "*TeamStatsView", + Cardinality: state.Many, + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "user_team", + PackagePath: "github.com/acme/project/shape/dev/team/user_team", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if strings.Contains(generated, "UserTeamTeamStatsView") { + t.Fatalf("expected snake_case resource view to resolve to TeamStatsView, got stale nested helper:\n%s", generated) + } + if !strings.Contains(generated, "TeamStats []*TeamStatsView") { + t.Fatalf("expected TeamStats view input field to reference TeamStatsView:\n%s", generated) + } +} + +func TestComponentCodegen_ExecWithoutExplicitOutput_DoesNotSynthesizeReaderData(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "update") + + component := &shapeload.Component{ + Method: "POST", + URI: "/v1/api/shape/dev/auth/products/", + RootView: "ProductUpdate", + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "ProductUpdate", + Mode: view.ModeExec, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "update", + PackagePath: "github.com/acme/project/shape/dev/vendor/update", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if strings.Contains(generated, "parameter:\",kind=output,in=view\"") { + t.Fatalf("did not expect default reader output field for exec component:\n%s", generated) + } +} + +func TestComponentCodegen_ImportsNamedNonStructFieldTypes(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "user", "mysql_boolean") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/user-metadata", + RootView: "UserMetadata", + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "UserMetadata", + Schema: state.NewSchema(reflect.TypeOf([]struct { + ID int + IsEnabled *types.BitBool + }{})), + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "IS_ENABLED", DataType: "types.BitBool", Nullable: true}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "mysql_boolean", + PackagePath: "github.com/acme/project/shape/dev/user/mysql_boolean", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if !strings.Contains(generated, `"github.com/viant/sqlx/types"`) { + t.Fatalf("expected generated source to import github.com/viant/sqlx/types:\n%s", generated) + } +} + +func TestComponentCodegen_EmitsNamedInputHasHelper(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "user", "mysql_boolean") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/user-metadata", + RootView: "UserMetadata", + Input: []*shapeplan.State{ + {Parameter: state.Parameter{Name: "Fields", In: state.NewQueryLocation("fields"), Schema: &state.Schema{DataType: "[]string"}}}, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "UserMetadata", + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "mysql_boolean", + PackagePath: "github.com/acme/project/shape/dev/user/mysql_boolean", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if !strings.Contains(generated, "type UserMetadataInputHas struct") { + t.Fatalf("expected named UserMetadataInputHas helper declaration:\n%s", generated) + } +} + +func TestComponentCodegen_ExecWithoutStatusDoesNotImportResponse(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "team", "delete") + + component := &shapeload.Component{ + Method: "DELETE", + URI: "/v1/api/shape/dev/team/{teamID}", + RootView: "Team", + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "Team", + Mode: view.ModeExec, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "delete", + PackagePath: "github.com/acme/project/shape/dev/team/delete", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if strings.Contains(generated, `"github.com/viant/xdatly/handler/response"`) { + t.Fatalf("did not expect response import for empty exec output:\n%s", generated) + } +} + +func TestComponentCodegen_MutableView_EmbedsHasMarker(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "events", "patch_basic_one") + + component := &shapeload.Component{ + Method: "PATCH", + URI: "/v1/api/shape/dev/basic/foos", + RootView: "foos", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Schema: &state.Schema{ + Name: "FoosView", + DataType: "*FoosView", + Cardinality: state.One, + }, + Tag: `anonymous:"true" typeName:"FoosView"`, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoos", + In: state.NewViewLocation("CurFoos"), + Schema: &state.Schema{ + Name: "FoosView", + DataType: "*FoosView", + Cardinality: state.One, + }, + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "foos", + Mode: view.ModeExec, + Template: &view.Template{ + Source: "#set($_ = $Foos(body/).Required())", + }, + Schema: &state.Schema{Name: "FoosView", Cardinality: state.Many}, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + {Name: "QUANTITY", DataType: "int", Nullable: true}, + }, + }, + &view.View{ + Name: "CurFoos", + Mode: view.ModeQuery, + Schema: &state.Schema{Name: "FoosView", Cardinality: state.Many}, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + {Name: "QUANTITY", DataType: "int", Nullable: true}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "patch_basic_one", + PackagePath: "github.com/acme/project/shape/dev/events/patch_basic_one", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + + if !strings.Contains(generated, "type FoosViewHas struct") { + t.Fatalf("expected mutable helper type declaration:\n%s", generated) + } + if !strings.Contains(generated, `Has *FoosViewHas `+"`"+`setMarker:"true" format:"-" sqlx:"-" diff:"-" json:"-" typeName:"FoosViewHas"`+"`") { + t.Fatalf("expected mutable view to embed Has marker:\n%s", generated) + } + if !strings.Contains(generated, `Foos *FoosView `+"`"+`parameter:",kind=body" typeName:"FoosView" anonymous:"true"`+"`") { + t.Fatalf("expected mutable body input to stay pointer typed:\n%s", generated) + } +} diff --git a/repository/shape/xgen/codegen_mutable_body_test.go b/repository/shape/xgen/codegen_mutable_body_test.go new file mode 100644 index 000000000..787b37406 --- /dev/null +++ b/repository/shape/xgen/codegen_mutable_body_test.go @@ -0,0 +1,195 @@ +package xgen + +import ( + "reflect" + "strings" + "testing" + + shapeload "github.com/viant/datly/repository/shape/load" + shapeplan "github.com/viant/datly/repository/shape/plan" + shapeast "github.com/viant/datly/repository/shape/velty/ast" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type mutableBodyFoos struct { + Id *int + Name *string + FoosPerformance []*mutableBodyFoosPerformance `view:",table=FOOS_PERFORMANCE" on:"Id:ID=FooId:FOO_ID"` +} + +type mutableBodyFoosPerformance struct { + Id *int + FooId *int + Name *string +} + +func TestComponentCodegen_BuildMutableVeltyBlock_PatchOne(t *testing.T) { + codegen, inputType, support := newMutableBodyFixture("PATCH", false) + actual := renderMutableBlock(t, codegen, inputType, support) + for _, fragment := range []string{ + `$sequencer.Allocate("FOOS", $Foos, "Id")`, + `#set($CurFoosById = $CurFoos.IndexBy("Id"))`, + `#if($Foos)`, + `#if($CurFoosById.HasKey($Foos.Id) == true)`, + `$sql.Update($Foos, "FOOS");`, + `$sql.Insert($Foos, "FOOS");`, + } { + if !strings.Contains(actual, fragment) { + t.Fatalf("expected fragment %q in generated body:\n%s", fragment, actual) + } + } +} + +func TestComponentCodegen_BuildMutableVeltyBlock_PatchMany(t *testing.T) { + codegen, inputType, support := newMutableBodyFixture("PATCH", true) + actual := renderMutableBlock(t, codegen, inputType, support) + for _, fragment := range []string{ + `$sequencer.Allocate("FOOS", $Foos, "Id")`, + `#set($CurFoosById = $CurFoos.IndexBy("Id"))`, + `#foreach($RecFoos in $Foos)`, + `#if($CurFoosById.HasKey($RecFoos.Id) == true)`, + `$sql.Update($RecFoos, "FOOS");`, + `$sql.Insert($RecFoos, "FOOS");`, + } { + if !strings.Contains(actual, fragment) { + t.Fatalf("expected fragment %q in generated body:\n%s", fragment, actual) + } + } +} + +func TestComponentCodegen_BuildMutableVeltyBlock_PutOne(t *testing.T) { + codegen, inputType, support := newMutableBodyFixture("PUT", false) + actual := renderMutableBlock(t, codegen, inputType, support) + if strings.Contains(actual, `$sequencer.Allocate(`) { + t.Fatalf("did not expect sequence allocation in PUT body:\n%s", actual) + } + if strings.Contains(actual, `$sql.Insert($Foos, "FOOS");`) { + t.Fatalf("did not expect insert branch in PUT body:\n%s", actual) + } + for _, fragment := range []string{ + `#set($CurFoosById = $CurFoos.IndexBy("Id"))`, + `#if($Foos)`, + `#if($CurFoosById.HasKey($Foos.Id) == true)`, + `$sql.Update($Foos, "FOOS");`, + } { + if !strings.Contains(actual, fragment) { + t.Fatalf("expected fragment %q in generated body:\n%s", fragment, actual) + } + } +} + +func TestComponentCodegen_BuildMutableVeltyBlock_PostMany(t *testing.T) { + codegen, inputType, support := newMutableBodyFixture("POST", true) + actual := renderMutableBlock(t, codegen, inputType, support) + if strings.Contains(actual, `HasKey`) || strings.Contains(actual, `$sql.Update(`) { + t.Fatalf("did not expect update logic in POST body:\n%s", actual) + } + for _, fragment := range []string{ + `$sequencer.Allocate("FOOS", $Foos, "Id")`, + `#set($CurFoosById = $CurFoos.IndexBy("Id"))`, + `#foreach($RecFoos in $Foos)`, + `$sql.Insert($RecFoos, "FOOS");`, + } { + if !strings.Contains(actual, fragment) { + t.Fatalf("expected fragment %q in generated body:\n%s", fragment, actual) + } + } +} + +func TestComponentCodegen_BuildMutableVeltyBlock_PatchManyMany(t *testing.T) { + codegen, inputType, support := newMutableBodyFixture("PATCH", true) + actual := renderMutableBlock(t, codegen, inputType, support) + for _, fragment := range []string{ + `$sequencer.Allocate("FOOS_PERFORMANCE", $Foos, "FoosPerformance/Id")`, + `#foreach($RecFoosPerformance in $RecFoos.FoosPerformance)`, + `#set($RecFoosPerformance.FooId = $RecFoos.Id)`, + `#if($CurFoosPerformanceById.HasKey($RecFoosPerformance.Id) == true)`, + `$sql.Update($RecFoosPerformance, "FOOS_PERFORMANCE");`, + `$sql.Insert($RecFoosPerformance, "FOOS_PERFORMANCE");`, + } { + if !strings.Contains(actual, fragment) { + t.Fatalf("expected fragment %q in generated body:\n%s", fragment, actual) + } + } +} + +func TestComponentCodegen_BuildMutableVeltyBlock_PutOneMany(t *testing.T) { + codegen, inputType, support := newMutableBodyFixture("PUT", false) + actual := renderMutableBlock(t, codegen, inputType, support) + if strings.Contains(actual, `$sql.Insert($RecFoosPerformance, "FOOS_PERFORMANCE");`) { + t.Fatalf("did not expect child insert branch in PUT body:\n%s", actual) + } + for _, fragment := range []string{ + `#foreach($RecFoosPerformance in $Foos.FoosPerformance)`, + `#set($RecFoosPerformance.FooId = $Foos.Id)`, + `#if($CurFoosPerformanceById.HasKey($RecFoosPerformance.Id) == true)`, + `$sql.Update($RecFoosPerformance, "FOOS_PERFORMANCE");`, + } { + if !strings.Contains(actual, fragment) { + t.Fatalf("expected fragment %q in generated body:\n%s", fragment, actual) + } + } +} + +func newMutableBodyFixture(method string, many bool) (*ComponentCodegen, reflect.Type, *mutableComponentSupport) { + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{Name: "CurFoos", Template: &view.Template{Source: "SELECT * FROM FOOS WHERE ID IN (?)"}}, + &view.View{Name: "CurFoosPerformance", Template: &view.Template{Source: "SELECT * FROM FOOS_PERFORMANCE WHERE ID IN (?)"}}, + ) + codegen := &ComponentCodegen{Component: &shapeload.Component{ + Method: method, + Input: []*shapeplan.State{ + {Parameter: state.Parameter{Name: "CurFoos", In: state.NewViewLocation("CurFoos"), Tag: `view:"CurFoos" sql:"uri=foos/cur_foos.sql"`}}, + {Parameter: state.Parameter{Name: "CurFoosPerformance", In: state.NewViewLocation("CurFoosPerformance"), Tag: `view:"CurFoosPerformance" sql:"uri=foos/cur_foos_performance.sql"`}}, + }, + }, Resource: resource} + bodyType := reflect.TypeOf(&mutableBodyFoos{}) + if many { + bodyType = reflect.TypeOf([]*mutableBodyFoos{}) + } + inputType := reflect.StructOf([]reflect.StructField{ + {Name: "Foos", Type: bodyType}, + {Name: "CurFoos", Type: reflect.TypeOf([]*mutableBodyFoos{})}, + }) + support := &mutableComponentSupport{ + BodyFieldName: "Foos", + Helpers: []mutableIndexHelper{ + { + ViewParamName: "CurFoos", + ViewFieldName: "CurFoos", + MapFieldName: "CurFoosById", + KeyFieldName: "Id", + KeyFieldType: "int", + TypeName: "Foos", + ItemTypeExpr: "*xgen.mutableBodyFoos", + ItemIsPointer: true, + }, + { + ViewParamName: "CurFoosPerformance", + ViewFieldName: "CurFoosPerformance", + MapFieldName: "CurFoosPerformanceById", + KeyFieldName: "Id", + KeyFieldType: "int", + TypeName: "FoosPerformance", + ItemTypeExpr: "*xgen.mutableBodyFoosPerformance", + ItemIsPointer: true, + }, + }, + } + return codegen, inputType, support +} + +func renderMutableBlock(t *testing.T, codegen *ComponentCodegen, inputType reflect.Type, support *mutableComponentSupport) string { + t.Helper() + block, err := codegen.buildMutableVeltyBlock(inputType, support) + if err != nil { + t.Fatalf("build mutable body: %v", err) + } + builder := shapeast.NewBuilder(shapeast.Options{Lang: shapeast.LangVelty}) + if err = block.Generate(builder); err != nil { + t.Fatalf("generate mutable body: %v", err) + } + return strings.TrimSpace(builder.String()) +} diff --git a/repository/shape/xgen/codegen_mutable_helpers_test.go b/repository/shape/xgen/codegen_mutable_helpers_test.go new file mode 100644 index 000000000..5368d4feb --- /dev/null +++ b/repository/shape/xgen/codegen_mutable_helpers_test.go @@ -0,0 +1,730 @@ +package xgen + +import ( + "os" + "path/filepath" + "reflect" + "regexp" + "strings" + "testing" + + shapeload "github.com/viant/datly/repository/shape/load" + shapeplan "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/shared" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type BasicFoos struct { + Id *int + Name *string +} + +type Foos struct { + Id *int + Name *string + FoosPerformance []*FoosPerformance `view:",table=FOOS_PERFORMANCE" on:"Id:ID=FooId:FOO_ID"` +} + +type FoosPerformance struct { + Id *int + FooId *int + Name *string +} + +func TestComponentCodegen_MutableComponent_GeneratesPatchHelpers(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "generate_patch_basic_one") + + component := &shapeload.Component{ + Method: "PATCH", + URI: "/v1/api/dev/basic/foos", + RootView: "Foos", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + Schema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.One}, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoosId", + In: state.NewParameterLocation("Foos"), + Schema: state.NewSchema(reflect.TypeOf(&struct{ Values []int }{})), + Tag: `codec:"structql,uri=foos/cur_foos_id.sql"`, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoos", + In: state.NewViewLocation("CurFoos"), + Tag: `view:"CurFoos" sql:"uri=foos/cur_foos.sql"`, + Schema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + }, + }, + }, + Output: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewOutputLocation("body"), + Tag: `anonymous:"true"`, + Schema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.One}, + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "Foos", + Mode: view.ModeExec, + Schema: &state.Schema{ + Name: "Foos", + DataType: "*Foos", + Cardinality: state.One, + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + }, + }, + &view.View{ + Name: "CurFoos", + Mode: view.ModeQuery, + Schema: &state.Schema{ + Name: "Foos", + DataType: "*Foos", + Cardinality: state.Many, + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "generate_patch_basic_one", + PackagePath: "github.com/acme/project/shape/dev/generate_patch_basic_one", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + + inputSource := mustReadCodegenFile(t, result.InputFilePath) + if !strings.Contains(inputSource, `CurFoosById map[int]Foos`) { + t.Fatalf("expected generated input to include indexed helper map:\n%s", inputSource) + } + if !strings.Contains(inputSource, "CurFoosId *struct {") || !strings.Contains(inputSource, "Values []int") { + t.Fatalf("expected generated input to preserve helper ids struct type:\n%s", inputSource) + } + + initSource := mustReadCodegenFile(t, filepath.Join(packageDir, "input_init.go")) + if !strings.Contains(initSource, `i.CurFoosById = make(map[int]Foos, len(i.CurFoos))`) { + t.Fatalf("expected generated init helper to allocate CurFoosById:\n%s", initSource) + } + if !strings.Contains(initSource, `i.CurFoosById[item.Id] = item`) { + t.Fatalf("expected generated init helper to populate CurFoosById:\n%s", initSource) + } + + validateSource := mustReadCodegenFile(t, filepath.Join(packageDir, "input_validate.go")) + if !strings.Contains(validateSource, `_, err := aValidator.Validate(ctx, value, append(options, validator.WithValidation(validation))...)`) { + t.Fatalf("expected generated validate helper to call validator service:\n%s", validateSource) + } + if !strings.Contains(validateSource, `case Foos:`) || !strings.Contains(validateSource, `_, ok := i.CurFoosById[actual.Id]`) { + t.Fatalf("expected generated validate helper to use CurFoosById marker provider:\n%s", validateSource) + } + + outputSource := mustReadCodegenFile(t, result.OutputFilePath) + if !strings.Contains(outputSource, `response.Status `+"`"+`parameter:",kind=output,in=status" json:",omitempty"`+"`") { + t.Fatalf("expected mutable output to embed response status:\n%s", outputSource) + } + if !strings.Contains(outputSource, `Violations validator.Violations `+"`"+`json:",omitempty"`+"`") { + t.Fatalf("expected mutable output to include validation violations:\n%s", outputSource) + } + if !strings.Contains(outputSource, `func (o *FoosOutput) setError(err error) {`) { + t.Fatalf("expected mutable output to include setError helper:\n%s", outputSource) + } + + if result.VeltyFilePath == "" { + t.Fatalf("expected mutable component to emit velty artifact path") + } + veltySource := mustReadCodegenFile(t, result.VeltyFilePath) + for _, fragment := range []string{ + `$sequencer.Allocate("FOOS", $Unsafe.Foos, "Id")`, + `#set($CurFoosById = $Unsafe.CurFoos.IndexBy("Id"))`, + `$sql.Update($Unsafe.Foos, "FOOS");`, + `$sql.Insert($Unsafe.Foos, "FOOS");`, + } { + if !strings.Contains(veltySource, fragment) { + t.Fatalf("expected generated velty body to include %q:\n%s", fragment, veltySource) + } + } + foundVelty := false + for _, generated := range result.GeneratedFiles { + if generated == result.VeltyFilePath { + foundVelty = true + break + } + } + if !foundVelty { + t.Fatalf("expected generated files to include velty artifact: %v", result.GeneratedFiles) + } + curIDsPath := filepath.Join(packageDir, "foos", "cur_foos_id.sql") + if _, err := os.Stat(curIDsPath); err != nil { + t.Fatalf("expected generated current-ids SQL at %s, files=%v", curIDsPath, result.GeneratedFiles) + } + curIDsSQL := mustReadCodegenFile(t, curIDsPath) + if !strings.Contains(curIDsSQL, `SELECT ARRAY_AGG(Id) AS Values`) { + t.Fatalf("expected generated current-ids SQL:\n%s", curIDsSQL) + } + curViewPath := filepath.Join(packageDir, "foos", "cur_foos.sql") + if _, err := os.Stat(curViewPath); err != nil { + t.Fatalf("expected generated current-view SQL at %s, files=%v", curViewPath, result.GeneratedFiles) + } + curViewSQL := mustReadCodegenFile(t, curViewPath) + if !strings.Contains(curViewSQL, `SELECT * FROM FOOS`) { + t.Fatalf("expected generated current-view SQL:\n%s", curViewSQL) + } +} + +func TestComponentCodegen_MutableComponent_DSQLParity_BasicOne(t *testing.T) { + result, packageDir := generateMutableFixture(t, mutableFixtureSpec{ + packageName: "generate_patch_basic_one", + method: "PATCH", + uri: "/v1/api/dev/basic/foos", + bodySchema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.One}, + outputSchema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.One}, + views: []*view.View{ + { + Name: "Foos", + Mode: view.ModeExec, + Connector: &view.Connector{Connection: view.Connection{DBConfig: view.DBConfig{Reference: shared.Reference{Ref: "dev"}}}}, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(&BasicFoos{})) + s.Name, s.DataType, s.Cardinality = "Foos", "*Foos", state.One + return s + }(), + Columns: []*view.Column{{Name: "ID", DataType: "int"}, {Name: "NAME", DataType: "string", Nullable: true}}, + }, + { + Name: "CurFoos", + Mode: view.ModeQuery, + Template: &view.Template{Source: "SELECT * FROM FOOS\nWHERE $criteria.In(\"ID\", $Unsafe.CurFoosId.Values)"}, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(&BasicFoos{})) + s.Name, s.DataType, s.Cardinality = "Foos", "*Foos", state.Many + return s + }(), + Columns: []*view.Column{{Name: "ID", DataType: "int"}, {Name: "NAME", DataType: "string", Nullable: true}}, + }, + }, + extraInput: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "CurFoosId", + In: state.NewParameterLocation("Foos"), + Schema: state.NewSchema(reflect.TypeOf(&struct{ Values []int }{})), + Tag: `codec:"structql,uri=foos/cur_foos_id.sql"`, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoos", + In: state.NewViewLocation("CurFoos"), + Tag: `view:"CurFoos" sql:"uri=foos/cur_foos.sql"`, + Schema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + }, + }, + }, + }) + assertMutableDSQLParity(t, result.VeltyFilePath, "/Users/awitas/go/src/github.com/viant/datly/e2e/local/dql/generate_patch_basic_one/patch_basic_one.sql") + assertMutableSQLFileParity(t, filepath.Join(packageDir, "foos", "cur_foos_id.sql"), "/Users/awitas/go/src/github.com/viant/datly/e2e/local/pkg/dev/generate_patch_basic_one/foos/cur_foos_id.sql") + assertMutableSQLFileParity(t, filepath.Join(packageDir, "foos", "cur_foos.sql"), "/Users/awitas/go/src/github.com/viant/datly/e2e/local/pkg/dev/generate_patch_basic_one/foos/cur_foos.sql") +} + +func TestComponentCodegen_MutableComponent_DSQLParity_BasicMany(t *testing.T) { + result, packageDir := generateMutableFixture(t, mutableFixtureSpec{ + packageName: "generate_patch_basic_many", + method: "PATCH", + uri: "/v1/api/dev/basic/foos-many", + bodySchema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + outputSchema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + views: []*view.View{ + { + Name: "Foos", + Mode: view.ModeExec, + Connector: &view.Connector{Connection: view.Connection{DBConfig: view.DBConfig{Reference: shared.Reference{Ref: "dev"}}}}, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(&BasicFoos{})) + s.Name, s.DataType, s.Cardinality = "Foos", "*Foos", state.Many + return s + }(), + Columns: []*view.Column{{Name: "ID", DataType: "int"}, {Name: "NAME", DataType: "string", Nullable: true}}, + }, + { + Name: "CurFoos", + Mode: view.ModeQuery, + Template: &view.Template{Source: "SELECT * FROM FOOS\nWHERE $criteria.In(\"ID\", $Unsafe.CurFoosId.Values)"}, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(&BasicFoos{})) + s.Name, s.DataType, s.Cardinality = "Foos", "*Foos", state.Many + return s + }(), + Columns: []*view.Column{{Name: "ID", DataType: "int"}, {Name: "NAME", DataType: "string", Nullable: true}}, + }, + }, + extraInput: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "CurFoosId", + In: state.NewParameterLocation("Foos"), + Schema: state.NewSchema(reflect.TypeOf(&struct{ Values []int }{})), + Tag: `codec:"structql,uri=foos/cur_foos_id.sql"`, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoos", + In: state.NewViewLocation("CurFoos"), + Tag: `view:"CurFoos" sql:"uri=foos/cur_foos.sql"`, + Schema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + }, + }, + }, + }) + assertMutableDSQLParity(t, result.VeltyFilePath, "/Users/awitas/go/src/github.com/viant/datly/e2e/local/dql/generate_patch_basic_many/patch_basic_many.sql") + assertMutableSQLFileParity(t, filepath.Join(packageDir, "foos", "cur_foos_id.sql"), "/Users/awitas/go/src/github.com/viant/datly/e2e/local/pkg/dev/generate_patch_basic_many/foos/cur_foos_id.sql") + assertMutableSQLFileParity(t, filepath.Join(packageDir, "foos", "cur_foos.sql"), "/Users/awitas/go/src/github.com/viant/datly/e2e/local/pkg/dev/generate_patch_basic_many/foos/cur_foos.sql") +} + +func TestComponentCodegen_MutableComponent_DSQLParity_ManyMany(t *testing.T) { + result, packageDir := generateMutableFixture(t, mutableFixtureSpec{ + packageName: "generate_patch_many_many", + method: "PATCH", + uri: "/v1/api/dev/basic/foos-many-many", + bodySchema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + outputSchema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + views: []*view.View{ + { + Name: "Foos", + Mode: view.ModeExec, + Connector: &view.Connector{Connection: view.Connection{DBConfig: view.DBConfig{Reference: shared.Reference{Ref: "dev"}}}}, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(&Foos{})) + s.Name, s.DataType, s.Cardinality = "Foos", "*Foos", state.Many + return s + }(), + Columns: []*view.Column{{Name: "ID", DataType: "int"}, {Name: "NAME", DataType: "string", Nullable: true}}, + }, + { + Name: "CurFoos", + Mode: view.ModeQuery, + Template: &view.Template{Source: "SELECT * FROM FOOS\nWHERE $criteria.In(\"ID\", $Unsafe.CurFoosId.Values)"}, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(&Foos{})) + s.Name, s.DataType, s.Cardinality = "Foos", "*Foos", state.Many + return s + }(), + Columns: []*view.Column{{Name: "ID", DataType: "int"}, {Name: "NAME", DataType: "string", Nullable: true}}, + }, + { + Name: "CurFoosPerformance", + Mode: view.ModeQuery, + Template: &view.Template{Source: "SELECT * FROM FOOS_PERFORMANCE\nWHERE $criteria.In(\"ID\", $Unsafe.CurFoosFoosPerformanceId.Values)"}, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(&FoosPerformance{})) + s.Name, s.DataType, s.Cardinality = "FoosPerformance", "*FoosPerformance", state.Many + return s + }(), + Columns: []*view.Column{{Name: "ID", DataType: "int"}, {Name: "FOO_ID", DataType: "int"}, {Name: "NAME", DataType: "string", Nullable: true}}, + }, + }, + extraInput: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "CurFoosId", + In: state.NewParameterLocation("Foos"), + Schema: state.NewSchema(reflect.TypeOf(&struct{ Values []int }{})), + Tag: `codec:"structql,uri=foos/cur_foos_id.sql"`, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoos", + In: state.NewViewLocation("CurFoos"), + Tag: `view:"CurFoos" sql:"uri=foos/cur_foos.sql"`, + Schema: &state.Schema{Name: "Foos", DataType: "*Foos", Cardinality: state.Many}, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoosFoosPerformanceId", + In: state.NewParameterLocation("Foos"), + Schema: state.NewSchema(reflect.TypeOf(&struct{ Values []int }{})), + Tag: `codec:"structql,uri=foos/cur_foos_foos_performance_id.sql"`, + }, + }, + { + Parameter: state.Parameter{ + Name: "CurFoosPerformance", + In: state.NewViewLocation("CurFoosPerformance"), + Tag: `view:"CurFoosPerformance" sql:"uri=foos/cur_foos_performance.sql"`, + Schema: &state.Schema{Name: "FoosPerformance", DataType: "*FoosPerformance", Cardinality: state.Many}, + }, + }, + }, + }) + assertMutableDSQLParity(t, result.VeltyFilePath, "/Users/awitas/go/src/github.com/viant/datly/e2e/local/dql/generate_patch_many_many/patch_basic_many_many.sql") + assertMutableSQLFileParity(t, filepath.Join(packageDir, "foos", "cur_foos_id.sql"), "/Users/awitas/go/src/github.com/viant/datly/e2e/local/pkg/dev/generate_patch_basic_many/foos/cur_foos_id.sql") + assertMutableSQLFileParity(t, filepath.Join(packageDir, "foos", "cur_foos.sql"), "/Users/awitas/go/src/github.com/viant/datly/e2e/local/pkg/dev/generate_patch_basic_many/foos/cur_foos.sql") + if !strings.Contains(mustReadCodegenFile(t, filepath.Join(packageDir, "foos", "cur_foos_foos_performance_id.sql")), "SELECT ARRAY_AGG(Id) AS Values FROM `/FoosPerformance` LIMIT 1") { + t.Fatalf("expected nested current-ids helper SQL") + } + if !strings.Contains(mustReadCodegenFile(t, filepath.Join(packageDir, "foos", "cur_foos_performance.sql")), "SELECT * FROM FOOS_PERFORMANCE") { + t.Fatalf("expected nested current-view helper SQL") + } +} + +func TestComponentCodegen_MutableComponent_UsesResourceViewKeyTypeForIndexMap(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendorsvc", "update") + + type legacyRecords struct { + Id string + } + + component := &shapeload.Component{ + Method: "POST", + URI: "/v1/api/shape/dev/auth/products/", + RootView: "ProductUpdate", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Ids", + In: state.NewBodyLocation("Ids"), + Schema: state.NewSchema(reflect.TypeOf([]int{})), + }, + }, + { + Parameter: state.Parameter{ + Name: "Records", + In: state.NewViewLocation("Records"), + Tag: `view:"Records" sql:"uri=product_update/Records.sql"`, + Schema: &state.Schema{Name: "RecordsView", DataType: "*RecordsView", Cardinality: state.Many}, + }, + }, + }, + Output: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Status", + In: state.NewOutputLocation("status"), + Schema: state.NewSchema(reflect.TypeOf("")), + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "ProductUpdate", + Mode: view.ModeExec, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf(struct{}{})) + s.Name, s.DataType = "ProductUpdateView", "*ProductUpdateView" + return s + }(), + }, + &view.View{ + Name: "Records", + Mode: view.ModeQuery, + Schema: func() *state.Schema { + s := state.NewSchema(reflect.TypeOf([]*legacyRecords{})) + s.Name, s.DataType, s.Cardinality = "RecordsView", "*RecordsView", state.Many + return s + }(), + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "update", + PackagePath: "github.com/acme/project/shape/dev/vendorsvc/update", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: false, + WithContract: false, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + inputSource := mustReadCodegenFile(t, result.InputFilePath) + if !strings.Contains(inputSource, `RecordsById map[int]*RecordsView`) { + t.Fatalf("expected generated input to use resource view key type for index map:\n%s", inputSource) + } + initSource := mustReadCodegenFile(t, filepath.Join(packageDir, "input_init.go")) + if !strings.Contains(initSource, `i.RecordsById = make(map[int]*RecordsView, len(i.Records))`) { + t.Fatalf("expected generated init helper to use int map key:\n%s", initSource) + } + if !strings.Contains(initSource, `i.RecordsById[item.Id] = item`) { + t.Fatalf("expected generated init helper to index by int key:\n%s", initSource) + } +} + +func mustReadCodegenFile(t *testing.T, path string) string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %s: %v", path, err) + } + return string(data) +} + +type mutableFixtureSpec struct { + packageName string + method string + uri string + bodySchema *state.Schema + outputSchema *state.Schema + views []*view.View + extraInput []*shapeplan.State +} + +func generateMutableFixture(t *testing.T, spec mutableFixtureSpec) (*ComponentCodegenResult, string) { + t.Helper() + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", spec.packageName) + component := &shapeload.Component{ + Method: spec.method, + URI: spec.uri, + RootView: "Foos", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + Schema: spec.bodySchema, + }, + }, + }, + Output: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewOutputLocation("body"), + Tag: `anonymous:"true"`, + Schema: spec.outputSchema, + }, + }, + }, + } + component.Input = append(component.Input, spec.extraInput...) + resource := view.EmptyResource() + resource.Views = append(resource.Views, spec.views...) + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: spec.packageName, + PackagePath: "github.com/acme/project/shape/dev/" + spec.packageName, + } + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + return result, packageDir +} + +func assertMutableDSQLParity(t *testing.T, actualPath, expectedPath string) { + t.Helper() + actual := normalizeMutableSQL(mustReadCodegenFile(t, actualPath)) + expected := normalizeMutableSQL(mustReadCodegenFile(t, expectedPath)) + if actual != expected { + t.Fatalf("mutable DSQL mismatch\nexpected:\n%s\n\nactual:\n%s", mustReadCodegenFile(t, expectedPath), mustReadCodegenFile(t, actualPath)) + } +} + +func assertMutableSQLFileParity(t *testing.T, actualPath, expectedPath string) { + t.Helper() + actual := normalizeMutableSQL(mustReadCodegenFile(t, actualPath)) + expected := normalizeMutableSQL(mustReadCodegenFile(t, expectedPath)) + if actual != expected { + t.Fatalf("mutable helper SQL mismatch for %s\nexpected:\n%s\n\nactual:\n%s", actualPath, mustReadCodegenFile(t, expectedPath), mustReadCodegenFile(t, actualPath)) + } +} + +func normalizeMutableSQL(value string) string { + value = strings.ReplaceAll(value, "\r\n", "\n") + lines := strings.Split(value, "\n") + out := make([]string, 0, len(lines)) + ws := regexp.MustCompile(`\s+`) + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + line = ws.ReplaceAllString(line, " ") + out = append(out, line) + } + return strings.Join(out, "\n") +} + +func TestComponentCodegen_MutableComponent_MergesRootTemplateHelpersIntoInput(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "patch_basic_one") + + component := &shapeload.Component{ + Method: "PATCH", + URI: "/v1/api/shape/dev/basic/foos", + RootView: "Foos", + Input: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + Schema: &state.Schema{Name: "FoosView", DataType: "*FoosView", Cardinality: state.One}, + }, + }, + }, + Output: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Foos", + In: state.NewOutputLocation("body"), + Tag: `anonymous:"true"`, + Schema: &state.Schema{Name: "FoosView", DataType: "*FoosView", Cardinality: state.One}, + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "Foos", + Mode: view.ModeExec, + Schema: &state.Schema{ + Name: "FoosView", + DataType: "*FoosView", + Cardinality: state.One, + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + {Name: "QUANTITY", DataType: "int", Nullable: true}, + }, + Template: &view.Template{ + UseParameterStateType: true, + Parameters: state.Parameters{ + { + Name: "Foos", + In: state.NewBodyLocation(""), + Tag: `anonymous:"true"`, + Schema: &state.Schema{Name: "FoosView", DataType: "*FoosView", Cardinality: state.One}, + }, + { + Name: "CurFoosId", + In: state.NewParameterLocation("Foos"), + Schema: state.NewSchema(reflect.TypeOf(&struct{ Values []int }{})), + Tag: `codec:"structql,uri=foos/cur_foos_id.sql"`, + }, + { + Name: "CurFoos", + In: state.NewViewLocation("CurFoos"), + Tag: `view:"CurFoos" sql:"uri=foos/cur_foos.sql"`, + Schema: &state.Schema{Name: "FoosView", DataType: "*FoosView", Cardinality: state.Many}, + }, + }, + }, + }, + &view.View{ + Name: "CurFoos", + Mode: view.ModeQuery, + Schema: &state.Schema{ + Name: "FoosView", + DataType: "*FoosView", + Cardinality: state.Many, + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string", Nullable: true}, + {Name: "QUANTITY", DataType: "int", Nullable: true}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "patch_basic_one", + PackagePath: "github.com/acme/project/shape/dev/patch_basic_one", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + + inputSource := mustReadCodegenFile(t, result.InputFilePath) + for _, fragment := range []string{ + `CurFoosId *struct {`, + `Values []int`, + `CurFoos `, + `CurFoosById map[int]`, + } { + if !strings.Contains(inputSource, fragment) { + t.Fatalf("expected generated input to include %q:\n%s", fragment, inputSource) + } + } + + initSource := mustReadCodegenFile(t, filepath.Join(packageDir, "input_init.go")) + if !strings.Contains(initSource, `i.CurFoosById = make(map[int]FoosView, len(i.CurFoos))`) { + t.Fatalf("expected generated init helper to index CurFoos:\n%s", initSource) + } +} diff --git a/repository/shape/xgen/codegen_output_view_test.go b/repository/shape/xgen/codegen_output_view_test.go new file mode 100644 index 000000000..d93dbf53e --- /dev/null +++ b/repository/shape/xgen/codegen_output_view_test.go @@ -0,0 +1,79 @@ +package xgen + +import ( + "os" + "path/filepath" + "strings" + "testing" + + shapeload "github.com/viant/datly/repository/shape/load" + shapeplan "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestComponentCodegen_PreservesExplicitOutputViewOneCardinality(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "user_acl") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/auth/user-acl", + RootView: "user_acl", + Output: []*shapeplan.State{ + { + Parameter: state.Parameter{ + Name: "Data", + In: state.NewOutputLocation("view"), + Tag: `anonymous:"true"`, + Schema: &state.Schema{ + Cardinality: state.One, + }, + }, + }, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{ + Name: "user_acl", + Table: "USER_ACL", + Template: &view.Template{SourceURL: "user_acl/user_acl.sql"}, + Schema: &state.Schema{Name: "UserAclView", DataType: "*UserAclView", Cardinality: state.Many}, + Columns: []*view.Column{ + {Name: "UserID", DataType: "int"}, + }, + }) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "user_acl", + PackagePath: "github.com/acme/project/shape/dev/vendor/user_acl", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: false, + WithContract: false, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if !strings.Contains(generated, "Data *UserAclView `parameter:\",kind=output,in=view\" view:\"user_acl\" sql:\"uri=user_acl/user_acl.sql\" anonymous:\"true\"`") { + t.Fatalf("expected one-cardinality output view to generate pointer field, got:\n%s", generated) + } + if strings.Contains(generated, "Data []*UserAclView") { + t.Fatalf("expected output view not to generate slice field, got:\n%s", generated) + } +} diff --git a/repository/shape/xgen/codegen_placeholder_view_test.go b/repository/shape/xgen/codegen_placeholder_view_test.go new file mode 100644 index 000000000..cf417a6f3 --- /dev/null +++ b/repository/shape/xgen/codegen_placeholder_view_test.go @@ -0,0 +1,83 @@ +package xgen + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestRebuildResourceViewStructType_ReplacesPlaceholderColumnsPreservesRelations(t *testing.T) { + type placeholderProducts struct { + ID int `sqlx:"ID"` + } + type placeholderVendor struct { + Col1 string `sqlx:"name=col_1"` + Col2 string `sqlx:"name=col_2"` + Products []*placeholderProducts `view:",table=PRODUCT" sql:"uri=vendor/products.sql" sqlx:"-"` + } + + cols := []columnDescriptor{ + {name: "ID", dataType: "int", primaryKey: true}, + {name: "NAME", dataType: "string"}, + } + + rType := rebuildResourceViewStructType(reflect.TypeOf(placeholderVendor{}), cols, false) + require.NotNil(t, rType) + require.Equal(t, reflect.Struct, rType.Kind()) + + field, ok := rType.FieldByName("Id") + require.True(t, ok) + require.Equal(t, "ID", sqlxTagName(field.Tag.Get("sqlx"))) + + field, ok = rType.FieldByName("Name") + require.True(t, ok) + require.Equal(t, "NAME", sqlxTagName(field.Tag.Get("sqlx"))) + + _, ok = rType.FieldByName("Col1") + require.False(t, ok) + + field, ok = rType.FieldByName("Products") + require.True(t, ok) + require.Equal(t, `uri=vendor/products.sql`, field.Tag.Get("sql")) + require.Equal(t, ",table=PRODUCT", field.Tag.Get("view")) +} + +func TestComponentCodegen_UsesDiscoveredColumnsWhenRootTypeIsPlaceholder(t *testing.T) { + type placeholderProducts struct { + ID int `sqlx:"ID"` + } + type placeholderVendor struct { + Col1 string `sqlx:"name=col_1"` + Products []*placeholderProducts `view:",table=PRODUCT" sql:"uri=vendor/products.sql" sqlx:"-"` + } + + resource := view.EmptyResource() + resource.Views = view.Views{ + { + Name: "vendor", + Schema: &state.Schema{ + Name: "VendorView", + }, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "NAME", DataType: "string"}, + }, + }, + } + resource.Views[0].Schema.SetType(reflect.TypeOf([]placeholderVendor{})) + + codegen := &ComponentCodegen{Resource: resource} + rType := codegen.resourceViewStructType("vendor") + require.NotNil(t, rType) + _, ok := rType.FieldByName("Id") + require.True(t, ok) + _, ok = rType.FieldByName("Name") + require.True(t, ok) + _, ok = rType.FieldByName("Products") + require.True(t, ok) + _, ok = rType.FieldByName("Col1") + require.False(t, ok) +} diff --git a/repository/shape/xgen/codegen_relation_view_test.go b/repository/shape/xgen/codegen_relation_view_test.go new file mode 100644 index 000000000..d89a4e976 --- /dev/null +++ b/repository/shape/xgen/codegen_relation_view_test.go @@ -0,0 +1,183 @@ +package xgen + +import ( + "os" + "path/filepath" + "strings" + "testing" + + shapeload "github.com/viant/datly/repository/shape/load" + shapeplan "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestComponentCodegen_UsesMaterializedViewTypeForRelationHolders(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "details") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/vendors/{vendorID}", + RootView: "vendor", + Output: []*shapeplan.State{ + {Parameter: state.Parameter{Name: "Data", In: state.NewOutputLocation("view"), Schema: &state.Schema{Cardinality: state.Many}}}, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "vendor", + Table: "VENDOR", + Template: &view.Template{SourceURL: "wrapper/vendor.sql"}, + Schema: &state.Schema{Name: "VendorView", DataType: "*VendorView", Cardinality: state.Many}, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + With: []*view.Relation{ + { + Holder: "Products", + Cardinality: state.Many, + On: view.Links{&view.Link{Field: "Id", Column: "ID"}}, + Of: &view.ReferenceView{ + View: view.View{ + Name: "products", + Table: "PRODUCT", + Template: &view.Template{SourceURL: "wrapper/products.sql"}, + Schema: &state.Schema{Name: "ProductsView", DataType: "*ProductsView", Cardinality: state.Many}, + }, + On: view.Links{&view.Link{Field: "VendorId", Column: "VENDOR_ID"}}, + }, + }, + }, + }, + &view.View{ + Name: "products", + Table: "PRODUCT", + Template: &view.Template{SourceURL: "wrapper/products.sql"}, + Schema: &state.Schema{Name: "ProductsView", DataType: "*ProductsView", Cardinality: state.Many}, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "VENDOR_ID", DataType: "*int", Tag: `internal:"true"`}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "details", + PackagePath: "github.com/acme/project/shape/dev/vendor/details", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: false, + WithContract: false, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if !strings.Contains(generated, "Products []*ProductsView `view:\",table=PRODUCT\" on:\"Id:ID=VendorId:VENDOR_ID\" sql:\"uri=wrapper/products.sql\"`") { + t.Fatalf("expected generated VendorView to use named relation holder field, got:\n%s", generated) + } + if strings.Contains(generated, "*struct {") { + t.Fatalf("expected no anonymous relation structs, got:\n%s", generated) + } + if strings.Contains(generated, "table=(SELECT") { + t.Fatalf("expected no raw subquery text in relation view tag, got:\n%s", generated) + } +} + +func TestComponentCodegen_RelationHolderUsesGeneratedResolvedTypeName(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "list") + + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/vendors/", + RootView: "vendor", + TypeSpecs: map[string]*shapeload.TypeSpec{ + "view:products": {Key: "view:products", Role: shapeload.TypeRoleView, Alias: "products", TypeName: "Products"}, + }, + Output: []*shapeplan.State{ + {Parameter: state.Parameter{Name: "Data", In: state.NewOutputLocation("view"), Schema: &state.Schema{Cardinality: state.Many}}}, + }, + } + + resource := view.EmptyResource() + resource.Views = append(resource.Views, + &view.View{ + Name: "vendor", + Table: "VENDOR", + Template: &view.Template{SourceURL: "vendor/vendor.sql"}, + Schema: &state.Schema{Name: "Vendor", DataType: "*Vendor", Cardinality: state.Many}, + Columns: []*view.Column{{Name: "ID", DataType: "int"}}, + With: []*view.Relation{ + { + Holder: "Products", + Cardinality: state.Many, + On: view.Links{&view.Link{Field: "Id", Column: "ID"}}, + Of: &view.ReferenceView{ + View: view.View{ + Name: "products", + Table: "PRODUCT", + Template: &view.Template{SourceURL: "vendor/products.sql"}, + Schema: &state.Schema{Name: "ProductsView", DataType: "*ProductsView", Cardinality: state.Many}, + }, + On: view.Links{&view.Link{Field: "VendorId", Column: "VENDOR_ID"}}, + }, + }, + }, + }, + &view.View{ + Name: "products", + Table: "PRODUCT", + Template: &view.Template{SourceURL: "vendor/products.sql"}, + Schema: &state.Schema{Name: "ProductsView", DataType: "*ProductsView", Cardinality: state.Many}, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + {Name: "VENDOR_ID", DataType: "*int", Tag: `internal:"true"`}, + }, + }, + ) + + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "list", + PackagePath: "github.com/acme/project/shape/dev/vendor/list", + } + + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: false, + WithContract: false, + } + + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + generated := string(data) + if !strings.Contains(generated, "Products []*Products `view:") { + t.Fatalf("expected generated relation holder to use resolved generated child type, got:\n%s", generated) + } +} diff --git a/repository/shape/xgen/codegen_typespec_test.go b/repository/shape/xgen/codegen_typespec_test.go new file mode 100644 index 000000000..41c7910ac --- /dev/null +++ b/repository/shape/xgen/codegen_typespec_test.go @@ -0,0 +1,87 @@ +package xgen + +import ( + "os" + "path/filepath" + "strings" + "testing" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + shapeload "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" +) + +func TestComponentCodegen_TypeSpecs_InputOutputAndDest(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "pkg", "dev", "vendor") + component := &shapeload.Component{ + Method: "GET", + URI: "/v1/api/shape/dev/vendors/", + RootView: "vendor", + Directives: &dqlshape.Directives{ + Dest: "all.go", + RouterDest: "vendor_router.go", + }, + TypeSpecs: map[string]*shapeload.TypeSpec{ + "input": {Key: "input", Role: shapeload.TypeRoleInput, TypeName: "VendorReq"}, + "output": {Key: "output", Role: shapeload.TypeRoleOutput, TypeName: "VendorResp", Dest: "vendor_output.go"}, + "view:vendor": {Key: "view:vendor", Role: shapeload.TypeRoleView, Alias: "vendor", TypeName: "Vendor"}, + }, + } + resource := &view.Resource{ + Views: []*view.View{ + { + Name: "vendor", + Connector: &view.Connector{Connection: view.Connection{DBConfig: view.DBConfig{Name: "dev"}}}, + Columns: []*view.Column{ + {Name: "ID", DataType: "int"}, + }, + }, + }, + } + ctx := &typectx.Context{ + PackageDir: packageDir, + PackageName: "vendor", + PackagePath: "github.com/acme/project/pkg/dev/vendor", + } + codegen := &ComponentCodegen{ + Component: component, + Resource: resource, + TypeContext: ctx, + ProjectDir: projectDir, + WithEmbed: true, + WithContract: true, + } + result, err := codegen.Generate() + if err != nil { + t.Fatalf("generate: %v", err) + } + if filepath.Base(result.FilePath) != "vendor_output.go" { + t.Fatalf("expected destination override vendor_output.go, got %s", filepath.Base(result.FilePath)) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file: %v", err) + } + source := string(data) + expectContainsTypeSpec(t, source, "type VendorReq struct") + expectContainsTypeSpec(t, source, "type VendorResp struct") + expectContainsTypeSpec(t, source, "Data []*Vendor") + expectContainsTypeSpec(t, source, "reflect.TypeOf(VendorReq{})") + expectContainsTypeSpec(t, source, "reflect.TypeOf(VendorResp{})") + routerData, err := os.ReadFile(filepath.Join(packageDir, "vendor_router.go")) + if err != nil { + t.Fatalf("read generated router file: %v", err) + } + routerSource := string(routerData) + expectContainsTypeSpec(t, routerSource, "type VendorRouter struct") + expectContainsTypeSpec(t, routerSource, "Vendor xdatly.Component[VendorReq, VendorResp]") +} + +func expectContainsTypeSpec(t *testing.T, source string, fragment string) { + t.Helper() + if !strings.Contains(source, fragment) { + t.Fatalf("expected generated source to contain %q\nsource:\n%s", fragment, source) + } +} diff --git a/repository/shape/xgen/generator.go b/repository/shape/xgen/generator.go index d0fc419a8..0d0d59110 100644 --- a/repository/shape/xgen/generator.go +++ b/repository/shape/xgen/generator.go @@ -12,10 +12,12 @@ import ( "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/typectx" "github.com/viant/datly/repository/shape/typectx/source" + "github.com/viant/datly/view" "github.com/viant/x" xreflectloader "github.com/viant/x/loader/xreflect" "github.com/viant/x/syntetic" "github.com/viant/x/syntetic/model" + "github.com/viant/xunsafe" ) // GenerateFromDQLShape emits Go structs from DQL shape using viant/x registry. @@ -52,24 +54,48 @@ func GenerateFromDQLShape(doc *shape.Document, cfg *Config) (*Result, error) { } typeNames := make([]string, 0, len(views)+len(routeTypes)) registered := map[string]bool{} + includeVelty := documentUsesVelty(doc) for _, view := range views { typeName := viewTypeName(cfg, view) if registered[typeName] { continue } + structType := buildStructType(view.columns, includeVelty) + if structType == nil { + continue + } registered[typeName] = true - if err = registerShapeType(registry, packagePath, typeName, buildStructType(view.columns)); err != nil { + if err = registerShapeType(registry, packagePath, typeName, structType); err != nil { return nil, err } typeNames = append(typeNames, typeName) + + // Generate Has marker struct for mutable views + if view.mutable && len(view.columns) > 0 { + hasTypeName := typeName + "Has" + if !registered[hasTypeName] { + hasType := buildHasType(view.columns) + if hasType != nil { + registered[hasTypeName] = true + if err = registerShapeType(registry, packagePath, hasTypeName, hasType); err != nil { + return nil, err + } + typeNames = append(typeNames, hasTypeName) + } + } + } } for _, ioType := range routeTypes { typeName := routeTypeName(cfg, ioType) if typeName == "" || registered[typeName] { continue } + structType := buildStructType(ioType.fields, includeVelty) + if structType == nil { + continue + } registered[typeName] = true - if err = registerShapeType(registry, packagePath, typeName, buildStructType(ioType.fields)); err != nil { + if err = registerShapeType(registry, packagePath, typeName, structType); err != nil { return nil, err } typeNames = append(typeNames, typeName) @@ -88,7 +114,9 @@ func GenerateFromDQLShape(doc *shape.Document, cfg *Config) (*Result, error) { if goFile == nil { return nil, fmt.Errorf("shape xgen: missing generated package file for %s", packagePath) } - source, err := goFile.Render() + source, err := goFile.RenderWithOptions(model.RenderOptions{ + Header: "// Code generated by datly transcribe. DO NOT EDIT.", + }) if err != nil { return nil, err } @@ -285,7 +313,50 @@ func uniqueStrings(items []string) []string { return result } +// registerShapeType registers a type in the x.Registry. If a type with the same +// name already exists (linked-in binary type via xunsafe, or previously registered +// in the registry), it preserves the existing field order and appends new fields. func registerShapeType(registry *x.Registry, packagePath string, typeName string, rType reflect.Type) error { + existingType := lookupExistingType(registry, packagePath, typeName) + if existingType != nil && existingType.Kind() == reflect.Struct && rType.Kind() == reflect.Struct { + var newFields []reflect.StructField + for i := 0; i < rType.NumField(); i++ { + newFields = append(newFields, rType.Field(i)) + } + rType = reflect.StructOf(mergeFieldOrder(existingType, newFields)) + } + return registerShapeTypeRaw(registry, packagePath, typeName, rType) +} + +// lookupExistingType checks for a previously known type: +// 1. First checks linked-in types via xunsafe.LookupType (compiled binary types) +// 2. Then checks the x.Registry (previously registered synthetic types) +// Returns the unwrapped struct reflect.Type, or nil if not found. +func lookupExistingType(registry *x.Registry, packagePath, typeName string) reflect.Type { + // xunsafe indexes by "pkgPath/TypeName" + if linked := xunsafe.LookupType(packagePath + "/" + typeName); linked != nil { + for linked.Kind() == reflect.Ptr || linked.Kind() == reflect.Slice { + linked = linked.Elem() + } + if linked.Kind() == reflect.Struct { + return linked + } + } + // Check registry + key := packagePath + "." + typeName + if existing := registry.Lookup(key); existing != nil && existing.Type != nil { + t := existing.Type + for t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + return t + } + } + return nil +} + +func registerShapeTypeRaw(registry *x.Registry, packagePath string, typeName string, rType reflect.Type) error { st, err := xreflectloader.BuildType(rType, xreflectloader.WithPackagePath(packagePath), xreflectloader.WithNamePolicy(func(reflect.Type) (string, bool) { @@ -310,6 +381,7 @@ type viewDescriptor struct { name any schemaName any columns []columnDescriptor + mutable bool } type ioTypeKind string @@ -329,8 +401,11 @@ type routeIODescriptor struct { } type columnDescriptor struct { - name string - dataType string + name string + dataType string + primaryKey bool + autoIncrement bool + nullable bool } func extractViews(root map[string]any) []viewDescriptor { @@ -353,6 +428,10 @@ func extractViews(root map[string]any) []viewDescriptor { if schema != nil { descriptor.schemaName = schema["Name"] } + mode := strings.ToLower(asString(view["Mode"])) + if mode == "sqlexec" || mode == "exec" || mode == "handler" { + descriptor.mutable = true + } descriptor.columns = extractColumns(view) result = append(result, descriptor) } @@ -371,7 +450,29 @@ func extractColumns(view map[string]any) []columnDescriptor { if name == "" { continue } - result = append(result, columnDescriptor{name: name, dataType: asString(column["DataType"])}) + col := columnDescriptor{ + name: name, + dataType: asString(column["DataType"]), + } + // Check for primary key / autoincrement from column tag or metadata + tag := strings.ToLower(asString(column["Tag"])) + if strings.Contains(tag, "primarykey") || strings.Contains(tag, "primary_key") { + col.primaryKey = true + } + if strings.Contains(tag, "autoincrement") || strings.Contains(tag, "auto_increment") { + col.autoIncrement = true + } + // Heuristic: column named ID or ending in _ID at position 0 is likely PK + if strings.EqualFold(name, "ID") && len(result) == 0 { + col.primaryKey = true + if strings.EqualFold(col.dataType, "int") || strings.EqualFold(col.dataType, "integer") || strings.EqualFold(col.dataType, "int64") { + col.autoIncrement = true + } + } + if asBool(column["Nullable"]) { + col.nullable = true + } + result = append(result, col) } } if cfg := asMap(view["ColumnsConfig"]); len(cfg) > 0 { @@ -386,12 +487,16 @@ func extractColumns(view map[string]any) []columnDescriptor { item = map[string]any{} } name := firstNonEmpty(asString(item["Name"]), key) - result = append(result, columnDescriptor{name: name, dataType: asString(item["DataType"])}) + col := columnDescriptor{name: name, dataType: asString(item["DataType"])} + if strings.EqualFold(name, "ID") && len(result) == 0 { + col.primaryKey = true + if strings.EqualFold(col.dataType, "int") || col.dataType == "" { + col.autoIncrement = true + } + } + result = append(result, col) } } - if len(result) == 0 { - result = append(result, columnDescriptor{name: "ID", dataType: "int"}) - } return result } @@ -458,15 +563,163 @@ func extractIOFields(io map[string]any) []columnDescriptor { } fields = append(fields, columnDescriptor{name: name, dataType: dataType}) } + return fields +} + +// mergeFieldOrder preserves existing field positions from a previously registered type. +// Existing fields keep their index; new fields are appended; removed fields are kept. +func mergeFieldOrder(existing reflect.Type, newFields []reflect.StructField) []reflect.StructField { + if existing == nil || existing.Kind() != reflect.Struct { + return newFields + } + // Index new fields by name + newByName := map[string]reflect.StructField{} + for _, f := range newFields { + newByName[f.Name] = f + } + // Start with existing fields in order, updating type/tag if regenerated + var merged []reflect.StructField + seen := map[string]bool{} + for i := 0; i < existing.NumField(); i++ { + ef := existing.Field(i) + seen[ef.Name] = true + if nf, ok := newByName[ef.Name]; ok { + // Keep position. If existing field has an explicit type override + // (from DQL cast like CAST(col AS MyType)), preserve the existing + // type — the user's intent takes precedence over DB discovery. + if hasExplicitTypeOverride(ef) { + nf.Type = ef.Type + } + merged = append(merged, nf) + } else { + // Column removed from DB — keep field for stability + merged = append(merged, ef) + } + } + // Append genuinely new fields + for _, nf := range newFields { + if !seen[nf.Name] { + merged = append(merged, nf) + } + } + return merged +} + +// hasExplicitTypeOverride returns true if the field's type was explicitly set +// by a DQL cast (e.g., CAST(col AS MyType)) rather than inferred from DB discovery. +// Detected via typeName tag or non-primitive type that isn't a standard DB mapping. +func hasExplicitTypeOverride(field reflect.StructField) bool { + tag := field.Tag + // typeName tag indicates an explicit type name was set + if v := tag.Get("typeName"); v != "" { + return true + } + // Check if the type is a named type (not a primitive or pointer-to-primitive) + ft := field.Type + for ft.Kind() == reflect.Ptr || ft.Kind() == reflect.Slice { + ft = ft.Elem() + } + if ft.PkgPath() != "" && ft.Kind() == reflect.Struct { + // Named struct from a package = explicit type (e.g., jwt.Claims, time.Time) + name := ft.Name() + if name != "" && name != "Time" { // time.Time is a standard DB mapping + return true + } + } + return false +} + +func buildStructType(columns []columnDescriptor, includeVelty bool) reflect.Type { + fields := buildStructFields(columns, includeVelty) if len(fields) == 0 { - fields = append(fields, columnDescriptor{name: "ID", dataType: "int"}) + return nil + } + return reflect.StructOf(fields) +} + +func buildStructFields(columns []columnDescriptor, includeVelty bool) []reflect.StructField { + if len(columns) == 0 { + return nil + } + fields := make([]reflect.StructField, 0, len(columns)) + used := map[string]int{} + for _, column := range columns { + fieldName := exportedName(column.name) + if fieldName == "" { + fieldName = "Field" + } + if count := used[fieldName]; count > 0 { + fieldName = fmt.Sprintf("%s%d", fieldName, count+1) + } + used[fieldName]++ + fieldType := parseType(column.dataType) + sqlxTag := column.name + isPK := column.primaryKey + isAutoInc := column.autoIncrement + if isPK { + sqlxTag += ",primaryKey" + } + if isAutoInc { + sqlxTag += ",autoincrement" + } + // Use pointer types for nullable (non-PK) columns to match legacy codegen + if !isPK && !isAutoInc && fieldType.Kind() != reflect.Slice && fieldType.Kind() != reflect.Ptr { + fieldType = reflect.PointerTo(fieldType) + } + tag := fmt.Sprintf(`sqlx:"%s" json:",omitempty"`, sqlxTag) + if isPK || isAutoInc { + tag = fmt.Sprintf(`sqlx:"%s"`, sqlxTag) + } + if includeVelty { + veltyNames := []string{column.name} + if fieldName != "" && fieldName != column.name { + veltyNames = append(veltyNames, fieldName) + } + veltyTag := fmt.Sprintf(`velty:"names=%s"`, strings.Join(veltyNames, "|")) + tag = tag + " " + veltyTag + } + fields = append(fields, reflect.StructField{ + Name: fieldName, + Type: fieldType, + Tag: reflect.StructTag(tag), + }) } return fields } -func buildStructType(columns []columnDescriptor) reflect.Type { +func documentUsesVelty(doc *shape.Document) bool { + if doc == nil || doc.Root == nil { + return false + } + var visit func(value any) bool + visit = func(value any) bool { + switch actual := value.(type) { + case map[string]any: + if mode := strings.TrimSpace(asString(actual["Mode"])); mode == string(view.ModeExec) { + return true + } + for _, item := range actual { + if visit(item) { + return true + } + } + case []any: + for _, item := range actual { + if visit(item) { + return true + } + } + } + return false + } + return visit(doc.Root) +} + +// buildHasType creates a marker struct with bool fields for each column. +// Used by mutable views to track which fields were explicitly set. +func buildHasType(columns []columnDescriptor) reflect.Type { if len(columns) == 0 { - columns = []columnDescriptor{{name: "ID", dataType: "int"}} + return nil } fields := make([]reflect.StructField, 0, len(columns)) used := map[string]int{} @@ -481,8 +734,7 @@ func buildStructType(columns []columnDescriptor) reflect.Type { used[fieldName]++ fields = append(fields, reflect.StructField{ Name: fieldName, - Type: parseType(column.dataType), - Tag: reflect.StructTag(fmt.Sprintf(`json:"%s,omitempty" sqlx:"%s"`, strings.ToLower(fieldName), column.name)), + Type: reflect.TypeOf(true), }) } return reflect.StructOf(fields) @@ -673,6 +925,16 @@ func asSlice(raw any) []any { return nil } +func asBool(raw any) bool { + if raw == nil { + return false + } + if v, ok := raw.(bool); ok { + return v + } + return false +} + func asString(raw any) string { if raw == nil { return "" diff --git a/repository/shape/xgen/generator_velty_tag_test.go b/repository/shape/xgen/generator_velty_tag_test.go new file mode 100644 index 000000000..2bf0afbfd --- /dev/null +++ b/repository/shape/xgen/generator_velty_tag_test.go @@ -0,0 +1,48 @@ +package xgen + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuildStructType_AddsVeltyNamesFromSQLColumns(t *testing.T) { + rType := buildStructType([]columnDescriptor{ + {name: "IS_AUTH", dataType: "int"}, + }, true) + require.NotNil(t, rType) + if rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + field, ok := rType.FieldByName("IsAuth") + require.True(t, ok) + require.Equal(t, `names=IS_AUTH|IsAuth`, field.Tag.Get("velty")) + require.Equal(t, `IS_AUTH`, field.Tag.Get("sqlx")) +} + +func TestBuildStructType_DedupesVeltyNamesWhenGoFieldMatchesColumn(t *testing.T) { + rType := buildStructType([]columnDescriptor{ + {name: "UserID", dataType: "int"}, + }, true) + require.NotNil(t, rType) + if rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + field, ok := rType.FieldByName("UserID") + require.True(t, ok) + require.Equal(t, `names=UserID`, field.Tag.Get("velty")) +} + +func TestBuildStructType_OmitsVeltyWhenDisabled(t *testing.T) { + rType := buildStructType([]columnDescriptor{ + {name: "USER_ID", dataType: "int"}, + }, false) + require.NotNil(t, rType) + if rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + field := rType.Field(0) + require.Equal(t, "", field.Tag.Get("velty")) + require.Equal(t, `USER_ID`, field.Tag.Get("sqlx")) +} diff --git a/repository/shape/xgen/io.go b/repository/shape/xgen/io.go index 50e86ddaa..3d2ff7b3e 100644 --- a/repository/shape/xgen/io.go +++ b/repository/shape/xgen/io.go @@ -223,6 +223,13 @@ func mergeGeneratedShapes(dest string, generated []byte, typeNames []string) ([] return out.Bytes(), nil } +// TODO: Field order preservation should be done at the viant/x registry level: +// 1. Check linked-in types first (runtime reflect.Type from registered types) +// 2. Fall back to viant/x/loader/ast.LoadPackageFS to load existing .go file +// 3. Extract field order from loaded types +// 4. When building new types, preserve existing field order and append new fields +// This avoids raw AST manipulation and handles complex type graphs (nested structs, relations). + func generatedShapeDecls(file *ast.File, typeNameSet map[string]bool) []ast.Decl { var result []ast.Decl for _, decl := range file.Decls { diff --git a/repository/shape/xgen/mutable_body.go b/repository/shape/xgen/mutable_body.go new file mode 100644 index 000000000..be1edba0c --- /dev/null +++ b/repository/shape/xgen/mutable_body.go @@ -0,0 +1,1016 @@ +package xgen + +import ( + "path/filepath" + "reflect" + "regexp" + "sort" + "strings" + + shapeast "github.com/viant/datly/repository/shape/velty/ast" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" +) + +func (g *ComponentCodegen) renderMutableVeltyBody(inputType reflect.Type) (string, bool, error) { + if inputType == nil { + var err error + inputType, err = g.mutableInputType() + if err != nil { + return "", false, err + } + } + support := g.mutableSupport(inputType) + if support == nil { + return "", false, nil + } + block, err := g.buildMutableVeltyBlock(inputType, support) + if err != nil { + return "", false, err + } + if block == nil { + return "", false, nil + } + builder := shapeast.NewBuilder(shapeast.Options{Lang: shapeast.LangVelty}) + if err = block.Generate(builder); err != nil { + return "", false, err + } + body := strings.TrimSpace(builder.String()) + body = g.normalizeMutableBodyReferences(body, support) + return body + "\n", true, nil +} + +func (g *ComponentCodegen) renderMutableDSQL(inputType reflect.Type) (string, bool, error) { + body, ok, err := g.renderMutableVeltyBody(inputType) + if err != nil || !ok { + return "", ok, err + } + return strings.TrimSpace(body) + "\n", true, nil +} + +func (g *ComponentCodegen) mutableTypeImports(support *mutableComponentSupport, inputType reflect.Type) []string { + items := map[string]struct{}{} + add := func(typeName string) { + typeName = strings.TrimSpace(typeName) + if typeName == "" { + return + } + pkg := strings.TrimSpace(g.PackagePath) + if pkg == "" && g.TypeContext != nil { + pkg = strings.TrimSpace(g.TypeContext.PackagePath) + } + if pkg == "" { + return + } + items[pkg] = struct{}{} + } + if bodyField, ok := inputType.FieldByName(support.BodyFieldName); ok { + if itemType, _ := mutableBodyItemType(bodyField.Type); itemType != nil { + typeName := strings.TrimSpace(support.BodyTypeName) + if typeName == "" { + typeName = itemType.Name() + } + add(typeName) + } + } + for _, helper := range support.Helpers { + typeName := strings.TrimSpace(helper.TypeName) + if typeName == "" && helper.ItemStruct != nil { + typeName = helper.ItemStruct.Name() + } + if typeName != "" { + add(typeName) + } + } + result := make([]string, 0, len(items)) + for item := range items { + result = append(result, item) + } + sort.Strings(result) + return result +} + +func (g *ComponentCodegen) mutableRouteOptionJSON() string { + connector := g.rootConnectorRef() + parts := []string{ + `"URI":"` + escapeJSON(strings.TrimSpace(g.Component.URI)) + `"`, + `"Method":"` + escapeJSON(strings.ToUpper(strings.TrimSpace(g.Component.Method))) + `"`, + } + if connector != "" { + parts = append(parts, `"Connector":"`+escapeJSON(connector)+`"`) + } + return "{" + strings.Join(parts, ",") + "}" +} + +func (g *ComponentCodegen) mutableBodyDeclaration(inputType reflect.Type, support *mutableComponentSupport) string { + bodyField, ok := inputType.FieldByName(support.BodyFieldName) + if !ok { + return "" + } + cardinality := "" + if g.mutableBodyMany(bodyField, support) { + cardinality = ".Cardinality('Many')" + } + return "#set($_ = $" + support.BodyFieldName + "(body/)" + cardinality + ".WithTag('anonymous:\"true\"').Required())\n" +} + +func (g *ComponentCodegen) mutableIDsDeclaration(helper mutableIndexHelper) string { + paramName := g.mutableIDsParamName(helper) + sqlText := g.mutableIDSQL(helper) + if paramName == "" || sqlText == "" { + return "" + } + return "\t#set($_ = $" + paramName + "(param/" + g.supportBodyFieldName(helper) + ") /*\n" + sqlText + "\n*/\n)\n" +} + +func (g *ComponentCodegen) mutableViewDeclaration(helper mutableIndexHelper) string { + viewName := strings.TrimSpace(helper.ViewFieldName) + if viewName == "" { + return "" + } + sqlText := g.mutableDeclarationViewSQL(helper) + if sqlText == "" { + return "" + } + typeExpr := strings.TrimSpace(helper.ItemTypeExpr) + if typeExpr == "" { + return "" + } + if g.mutableHelperUsesMany(helper) && !strings.HasPrefix(typeExpr, "[]") { + typeExpr = "[]" + typeExpr + } + return "\t#set($_ = $" + viewName + "<" + typeExpr + ">(view/" + viewName + ") /*\n" + sqlText + "\n*/\n)\n" +} + +func (g *ComponentCodegen) mutableHelperUsesMany(helper mutableIndexHelper) bool { + if g == nil { + return true + } + lookup := func(params state.Parameters) (bool, bool) { + for _, input := range params { + if input == nil || input.In == nil || input.In.Kind != state.KindView { + continue + } + if !strings.EqualFold(strings.TrimSpace(input.Name), strings.TrimSpace(helper.ViewParamName)) { + continue + } + if input.Schema == nil { + return true, true + } + return input.Schema.Cardinality == state.Many, true + } + return false, false + } + if g.Component != nil { + componentInputs := make(state.Parameters, 0, len(g.Component.Input)) + for _, input := range g.Component.Input { + if input == nil { + continue + } + componentInputs = append(componentInputs, &input.Parameter) + } + if many, ok := lookup(componentInputs); ok { + return many + } + } + if root := g.rootResourceView(); root != nil && root.Template != nil { + if many, ok := lookup(root.Template.Parameters); ok { + return many + } + } + if g.Resource != nil { + if many, ok := lookup(g.Resource.Parameters); ok { + return many + } + } + return true +} + +func (g *ComponentCodegen) mutableOutputDeclaration(inputType reflect.Type, support *mutableComponentSupport) string { + bodyField, ok := inputType.FieldByName(support.BodyFieldName) + if !ok { + return "" + } + cardinality := "" + if g.mutableBodyMany(bodyField, support) { + cardinality = ".Cardinality('Many')" + } + tag := `anonymous:"true"` + typeName := strings.TrimSpace(support.BodyTypeName) + if typeName == "" { + if itemType, _ := mutableBodyItemType(bodyField.Type); itemType != nil { + typeName = itemType.Name() + } + } + if typeName != "" { + tag += ` typeName:"` + typeName + `"` + } + return "#set($_ = $" + support.BodyFieldName + "(body/)" + cardinality + ".WithTag('" + tag + "').Required().Output())\n" +} + +func (g *ComponentCodegen) mutableQualifiedTypeName(typeName string) string { + typeName = strings.TrimSpace(typeName) + if typeName == "" { + return "" + } + return typeName +} + +func (g *ComponentCodegen) mutableIDSQL(helper mutableIndexHelper) string { + key := strings.TrimSpace(helper.KeyFieldName) + if key == "" { + key = "Id" + } + path := "/" + if helper.RelationPath != "" { + path += helper.RelationPath + } + return "? SELECT ARRAY_AGG(" + key + ") AS Values FROM `" + path + "` LIMIT 1" +} + +func (g *ComponentCodegen) mutableViewSQL(helper mutableIndexHelper) string { + if g == nil || g.Resource == nil { + return g.normalizeMutableViewSQL(helper, g.mutableFallbackViewSQL(helper)) + } + for _, aView := range g.Resource.Views { + if aView == nil || !strings.EqualFold(strings.TrimSpace(aView.Name), strings.TrimSpace(helper.ViewParamName)) { + continue + } + if aView.Template == nil { + return g.normalizeMutableViewSQL(helper, g.mutableFallbackViewSQL(helper)) + } + sqlText := strings.TrimSpace(aView.Template.Source) + if sqlText != "" { + return g.normalizeMutableViewSQL(helper, sqlText) + } + return g.normalizeMutableViewSQL(helper, g.mutableFallbackViewSQL(helper)) + } + return g.normalizeMutableViewSQL(helper, g.mutableFallbackViewSQL(helper)) +} + +func (g *ComponentCodegen) mutableDeclarationViewSQL(helper mutableIndexHelper) string { + sqlText := strings.TrimSpace(g.mutableViewSQL(helper)) + if sqlText == "" { + return "" + } + if strings.HasPrefix(sqlText, "?") { + return sqlText + } + return "? " + sqlText +} + +func (g *ComponentCodegen) mutableFallbackViewSQL(helper mutableIndexHelper) string { + tableName := "" + if helper.ItemStruct != nil && helper.ItemStruct.Name() != "" { + tableName = tableNameFromType(helper.ItemStruct.Name()) + } + if tableName == "" { + typeName := strings.TrimPrefix(strings.TrimPrefix(strings.TrimSpace(helper.ItemTypeExpr), "[]"), "*") + tableName = tableNameFromType(typeName) + } + if tableName == "" { + return "" + } + idParam := g.mutableIDsParamName(helper) + key := strings.TrimSpace(helper.KeyFieldName) + if key == "" { + key = "Id" + } + return "SELECT * FROM " + tableName + "\nWHERE $criteria.In(\"" + key + "\", $Unsafe." + idParam + ".Values)" +} + +func (g *ComponentCodegen) normalizeMutableViewSQL(helper mutableIndexHelper, sqlText string) string { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" { + return "" + } + idParam := strings.TrimSpace(g.mutableIDsParamName(helper)) + if idParam == "" { + return sqlText + } + legacy := "$" + idParam + ".Values" + normalized := "$Unsafe." + idParam + ".Values" + if strings.Contains(sqlText, legacy) && !strings.Contains(sqlText, normalized) { + sqlText = strings.ReplaceAll(sqlText, legacy, normalized) + } + return sqlText +} + +func (g *ComponentCodegen) normalizeMutableBodyReferences(body string, support *mutableComponentSupport) string { + body = strings.TrimSpace(body) + if body == "" || support == nil { + return body + } + names := []string{strings.TrimSpace(support.BodyFieldName)} + for _, helper := range support.Helpers { + if name := strings.TrimSpace(helper.ViewFieldName); name != "" { + names = append(names, name) + } + } + for _, name := range names { + if name == "" { + continue + } + pattern := regexp.MustCompile(`\$` + regexp.QuoteMeta(name) + `\b`) + body = pattern.ReplaceAllStringFunc(body, func(string) string { + return "$Unsafe." + name + }) + } + body = regexp.MustCompile(`#set\(\$([A-Za-z0-9_]+) =`).ReplaceAllString(body, `#set($$1 =`) + return body +} + +func (g *ComponentCodegen) supportBodyFieldName(helper mutableIndexHelper) string { + if g == nil { + return "" + } + inputType, err := g.mutableInputType() + if err != nil || inputType == nil { + return "" + } + if support := g.mutableSupport(inputType); support != nil { + return support.BodyFieldName + } + return "" +} + +func strconvQuote(s string) string { + return `"` + strings.ReplaceAll(s, `"`, `\"`) + `"` +} + +func escapeJSON(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `"`, `\"`) + return s +} + +func (g *ComponentCodegen) mutableIDHelpers(support *mutableComponentSupport) []mutableIndexHelper { + if support == nil || len(support.Helpers) == 0 { + return nil + } + ret := append([]mutableIndexHelper{}, support.Helpers...) + sort.SliceStable(ret, func(i, j int) bool { + leftDepth := mutableRelationDepth(ret[i].RelationPath) + rightDepth := mutableRelationDepth(ret[j].RelationPath) + if leftDepth != rightDepth { + return leftDepth < rightDepth + } + return ret[i].ViewFieldName < ret[j].ViewFieldName + }) + return ret +} + +func (g *ComponentCodegen) mutableViewHelpers(support *mutableComponentSupport) []mutableIndexHelper { + if support == nil || len(support.Helpers) == 0 { + return nil + } + ret := append([]mutableIndexHelper{}, support.Helpers...) + sort.SliceStable(ret, func(i, j int) bool { + leftDepth := mutableRelationDepth(ret[i].RelationPath) + rightDepth := mutableRelationDepth(ret[j].RelationPath) + if leftDepth != rightDepth { + return leftDepth > rightDepth + } + return ret[i].ViewFieldName > ret[j].ViewFieldName + }) + return ret +} + +func mutableRelationDepth(path string) int { + path = strings.Trim(path, "/") + if path == "" { + return 0 + } + return strings.Count(path, "/") + 1 +} + +type mutableGeneratedFile struct { + Path string + Content string +} + +func (g *ComponentCodegen) mutableHelperSQLFiles(support *mutableComponentSupport) []mutableGeneratedFile { + if g == nil || g.Component == nil || support == nil { + return nil + } + packageDir := strings.TrimSpace(g.PackageDir) + if packageDir == "" && g.TypeContext != nil { + packageDir = strings.TrimSpace(g.TypeContext.PackageDir) + } + if packageDir == "" { + return nil + } + var result []mutableGeneratedFile + helperByView := map[string]mutableIndexHelper{} + helperByID := map[string]mutableIndexHelper{} + for _, helper := range support.Helpers { + helperByView[strings.TrimSpace(helper.ViewFieldName)] = helper + helperByID[g.mutableIDsParamName(helper)] = helper + } + for _, input := range g.Component.Input { + if input == nil { + continue + } + switch { + case input.In != nil && input.In.Kind == state.KindView: + helper, ok := helperByView[strings.TrimSpace(input.Name)] + if !ok { + continue + } + rel := tagURIValue(input.Tag, "sql") + content := g.mutableViewSQL(helper) + if strings.TrimSpace(content) == "" { + continue + } + result = append(result, g.mutableGeneratedSQLFiles(packageDir, rel, g.mutableHelperViewRelPath(helper), content)...) + case input.In != nil && input.In.Kind == state.KindParam: + helper, ok := helperByID[strings.TrimSpace(input.Name)] + if !ok { + continue + } + rel := tagURIValue(input.Tag, "codec") + content := g.mutableIDSQL(helper) + if strings.TrimSpace(content) == "" { + continue + } + result = append(result, g.mutableGeneratedSQLFiles(packageDir, rel, g.mutableHelperIDsRelPath(helper), content)...) + } + } + for _, helper := range support.Helpers { + if content := g.mutableViewSQL(helper); strings.TrimSpace(content) != "" { + result = append(result, g.mutableGeneratedSQLFiles(packageDir, "", g.mutableHelperViewRelPath(helper), content)...) + } + if content := g.mutableIDSQL(helper); strings.TrimSpace(content) != "" { + result = append(result, g.mutableGeneratedSQLFiles(packageDir, "", g.mutableHelperIDsRelPath(helper), content)...) + } + } + return result +} + +func (g *ComponentCodegen) mutableGeneratedSQLFiles(packageDir, primaryRel, fallbackRel, content string) []mutableGeneratedFile { + seen := map[string]struct{}{} + var result []mutableGeneratedFile + appendFile := func(rel string) { + rel = strings.TrimSpace(rel) + if rel == "" { + return + } + abs := filepath.Join(packageDir, filepath.FromSlash(rel)) + if _, ok := seen[abs]; ok { + return + } + seen[abs] = struct{}{} + result = append(result, mutableGeneratedFile{Path: abs, Content: content}) + } + appendFile(primaryRel) + appendFile(fallbackRel) + return result +} + +func tagURIValue(tag, key string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return "" + } + needle := key + `:"` + start := strings.Index(tag, needle) + if start == -1 { + return "" + } + rest := tag[start+len(needle):] + end := strings.Index(rest, `"`) + if end == -1 { + return "" + } + value := rest[:end] + if idx := strings.Index(value, "uri="); idx >= 0 { + value = value[idx+4:] + if cut := strings.IndexAny(value, ", "); cut >= 0 { + value = value[:cut] + } + } + return strings.TrimSpace(value) +} + +func (g *ComponentCodegen) mutableIDsParamName(helper mutableIndexHelper) string { + name := "Cur" + support := g.mutableSupportMust() + if support != nil { + name += support.BodyFieldName + } + if helper.RelationPath != "" { + name += strings.ReplaceAll(helper.RelationPath, "/", "") + } + name += helper.KeyFieldName + return name +} + +func (g *ComponentCodegen) mutableHelperViewRelPath(helper mutableIndexHelper) string { + componentDir := text.CaseFormatUpperCamel.Format(strings.TrimSpace(g.componentName()), text.CaseFormatLowerUnderscore) + name := text.CaseFormatUpperCamel.Format(strings.TrimSpace(helper.ViewFieldName), text.CaseFormatLowerUnderscore) + if componentDir == "" || name == "" { + return "" + } + return filepath.ToSlash(filepath.Join(componentDir, name+".sql")) +} + +func (g *ComponentCodegen) mutableHelperIDsRelPath(helper mutableIndexHelper) string { + componentDir := text.CaseFormatUpperCamel.Format(strings.TrimSpace(g.componentName()), text.CaseFormatLowerUnderscore) + name := text.CaseFormatUpperCamel.Format(strings.TrimSpace(g.mutableIDsParamName(helper)), text.CaseFormatLowerUnderscore) + if componentDir == "" || name == "" { + return "" + } + return filepath.ToSlash(filepath.Join(componentDir, name+".sql")) +} + +func (g *ComponentCodegen) mutableSupportMust() *mutableComponentSupport { + inputType, err := g.mutableInputType() + if err != nil || inputType == nil { + return nil + } + return g.mutableSupport(inputType) +} + +func (g *ComponentCodegen) mutableInputType() (reflect.Type, error) { + if g == nil || g.Component == nil { + return nil, nil + } + params := g.codegenInputParameters() + opts := []state.ReflectOption{state.WithSetMarker(), state.WithTypeName(g.inputTypeName(g.componentName()))} + if g.componentUsesVelty() { + opts = append(opts, state.WithVelty(true)) + } + pkgPath := "" + if g.TypeContext != nil { + pkgPath = g.TypeContext.PackagePath + } + return params.ReflectType(pkgPath, g.componentLookupType(pkgPath), opts...) +} + +func (g *ComponentCodegen) buildMutableVeltyBlock(inputType reflect.Type, support *mutableComponentSupport) (shapeast.Block, error) { + var block shapeast.Block + bodyField, ok := inputType.FieldByName(support.BodyFieldName) + if !ok { + return nil, nil + } + bodyItemType, _ := mutableBodyItemType(bodyField.Type) + if bodyItemType == nil { + return nil, nil + } + bodyKeyField, ok := lookupGeneratedIndexField(bodyItemType) + if !ok { + return nil, nil + } + bodyTable := g.mutableBodyTableName(support, bodyItemType) + if bodyTable == "" { + return nil, nil + } + + g.appendMutableSequence(&block, shapeast.NewIdent(support.BodyFieldName), "", bodyItemType, bodyTable, bodyKeyField) + g.appendMutableRelationSequences(&block, shapeast.NewIdent(support.BodyFieldName), "", bodyItemType) + + for _, helper := range support.Helpers { + block.Append(shapeast.NewAssign( + g.mutableHelperMapHolder(helper), + shapeast.NewCallExpr(shapeast.NewIdent(helper.ViewFieldName), "IndexBy", shapeast.NewQuotedLiteral(helper.KeyFieldName)), + )) + } + if len(support.Helpers) > 0 { + block.AppendEmptyLine() + } + + rootHelper := support.rootHelper() + bodyExpr := shapeast.NewIdent(support.BodyFieldName) + if g.mutableBodyMany(bodyField, support) { + recordName := mutableRecordName(support.BodyFieldName) + forEach := shapeast.NewForEach(shapeast.NewIdent(recordName), bodyExpr, shapeast.Block{}) + g.appendMutableWriteLogic(&forEach.Body, shapeast.NewIdent(recordName), "", bodyItemType, bodyTable, support, rootHelper, bodyKeyField) + block.Append(forEach) + return block, nil + } + + g.appendMutableWriteLogic(&block, bodyExpr, "", bodyItemType, bodyTable, support, rootHelper, bodyKeyField) + return block, nil +} + +func (g *ComponentCodegen) mutableHelperMapHolder(helper mutableIndexHelper) shapeast.Expression { + return shapeast.NewIdent(helper.MapFieldName) +} + +func mutableItemExprIsPointer(itemTypeExpr string) bool { + itemTypeExpr = strings.TrimSpace(itemTypeExpr) + return strings.HasPrefix(itemTypeExpr, "*") || strings.HasPrefix(itemTypeExpr, "[]*") +} + +func (g *ComponentCodegen) mutableBodyMany(bodyField reflect.StructField, support *mutableComponentSupport) bool { + if support != nil && support.BodyMany { + return true + } + if g != nil && g.Component != nil { + for _, input := range g.Component.Input { + if input == nil || input.In == nil || input.In.Kind != state.KindRequestBody { + continue + } + if input.Schema != nil && input.Schema.Cardinality != "" { + return input.Schema.Cardinality == state.Many + } + break + } + } + _, many := mutableBodyItemType(bodyField.Type) + return many +} + +func (g *ComponentCodegen) appendMutableWriteLogic(block *shapeast.Block, recordExpr *shapeast.Ident, logicalPath string, recordType reflect.Type, tableName string, support *mutableComponentSupport, rootHelper *mutableIndexHelper, keyField reflect.StructField) { + method := strings.ToUpper(strings.TrimSpace(g.Component.Method)) + hasCurrent := rootHelper != nil + writeUpdate := method == "PATCH" || method == "PUT" + writeInsert := method == "PATCH" || method == "POST" + keyFieldName := keyField.Name + + if hasCurrent && writeUpdate { + hasKey := shapeast.NewBinary( + shapeast.NewCallExpr(shapeast.NewIdent(rootHelper.MapFieldName), "HasKey", shapeast.NewIdent(recordExpr.Name+"."+keyFieldName)), + "==", + shapeast.NewLiteral("true"), + ) + condition := shapeast.NewCondition(hasKey, shapeast.Block{}, nil) + condition.IFBlock.Append(shapeast.NewStatementExpression(shapeast.NewTerminatorExpression(shapeast.NewCallExpr( + shapeast.NewIdent("sql"), "Update", recordExpr, shapeast.NewQuotedLiteral(tableName), + )))) + if writeInsert { + condition.ElseBlock = shapeast.Block{ + shapeast.NewStatementExpression(shapeast.NewTerminatorExpression(shapeast.NewCallExpr( + shapeast.NewIdent("sql"), "Insert", recordExpr, shapeast.NewQuotedLiteral(tableName), + ))), + } + } + block.Append(condition) + g.appendChildMutableWriteLogic(block, recordExpr, logicalPath, recordType, support) + return + } + + if writeInsert { + block.Append(shapeast.NewStatementExpression(shapeast.NewTerminatorExpression(shapeast.NewCallExpr( + shapeast.NewIdent("sql"), "Insert", recordExpr, shapeast.NewQuotedLiteral(tableName), + )))) + } + g.appendChildMutableWriteLogic(block, recordExpr, logicalPath, recordType, support) +} + +func (g *ComponentCodegen) appendMutableSequence(block *shapeast.Block, bodyExpr *shapeast.Ident, path string, itemType reflect.Type, tableName string, keyField reflect.StructField) { + if !g.needsMutableSequence(keyField.Type) { + return + } + block.Append(shapeast.NewStatementExpression(shapeast.NewCallExpr( + shapeast.NewIdent("sequencer"), + "Allocate", + shapeast.NewQuotedLiteral(tableName), + bodyExpr, + shapeast.NewQuotedLiteral(mutableSequencePath(path, keyField.Name)), + ))) + block.AppendEmptyLine() +} + +func mutableSequencePath(path, key string) string { + path = strings.Trim(path, "/") + if path == "" { + return key + } + return path + "/" + key +} + +func (g *ComponentCodegen) appendChildMutableWriteLogic(block *shapeast.Block, parentExpr *shapeast.Ident, logicalPath string, parentType reflect.Type, support *mutableComponentSupport) { + parentType = unwrapNamedStructType(parentType) + if parentType == nil { + return + } + for i := 0; i < parentType.NumField(); i++ { + field := parentType.Field(i) + if !isMutableRelationField(field) { + continue + } + childItemType, childMany := mutableBodyItemType(field.Type) + if childItemType == nil { + continue + } + childKeyField, ok := lookupGeneratedIndexField(childItemType) + if !ok { + continue + } + childTable := mutableRelationTableName(field) + if childTable == "" { + childTable = tableNameFromType(childItemType.Name()) + } + if childTable == "" { + continue + } + childPath := mutableSequencePath(logicalPath, field.Name) + assignments := mutableRelationAssignments(field) + childHelper := support.findHelper(field.Name, field) + + childExprName := field.Name + if childMany { + recordName := mutableRecordName(field.Name) + forEach := shapeast.NewForEach(shapeast.NewIdent(recordName), shapeast.NewIdent(parentExpr.Name+"."+field.Name), shapeast.Block{}) + appendMutableRelationAssignments(&forEach.Body, shapeast.NewIdent(recordName), parentExpr, assignments, parentType, childItemType) + g.appendMutableWriteLogic(&forEach.Body, shapeast.NewIdent(recordName), childPath, childItemType, childTable, support, childHelper, childKeyField) + block.AppendEmptyLine() + block.Append(forEach) + continue + } + condition := shapeast.NewCondition(shapeast.NewIdent(parentExpr.Name+"."+childExprName), shapeast.Block{}, nil) + childExpr := shapeast.NewIdent(parentExpr.Name + "." + childExprName) + appendMutableRelationAssignments(&condition.IFBlock, childExpr, parentExpr, assignments, parentType, childItemType) + g.appendMutableWriteLogic(&condition.IFBlock, childExpr, childPath, childItemType, childTable, support, childHelper, childKeyField) + block.AppendEmptyLine() + block.Append(condition) + } +} + +func (g *ComponentCodegen) appendMutableRelationSequences(block *shapeast.Block, rootExpr *shapeast.Ident, logicalPath string, parentType reflect.Type) { + parentType = unwrapNamedStructType(parentType) + if parentType == nil { + return + } + for i := 0; i < parentType.NumField(); i++ { + field := parentType.Field(i) + if !isMutableRelationField(field) { + continue + } + childItemType, _ := mutableBodyItemType(field.Type) + if childItemType == nil { + continue + } + childKeyField, ok := lookupGeneratedIndexField(childItemType) + if !ok { + continue + } + childTable := mutableRelationTableName(field) + if childTable == "" { + childTable = tableNameFromType(childItemType.Name()) + } + if childTable == "" { + continue + } + childPath := mutableSequencePath(logicalPath, field.Name) + g.appendMutableSequence(block, rootExpr, childPath, childItemType, childTable, childKeyField) + g.appendMutableRelationSequences(block, rootExpr, childPath, childItemType) + } +} + +func (g *ComponentCodegen) mutableBodyTableName(support *mutableComponentSupport, bodyItemType reflect.Type) string { + if support != nil { + if rootHelper := support.rootHelper(); rootHelper != nil { + if name := g.mutableTableFromViewState(rootHelper.ViewParamName); name != "" { + return name + } + } + } + if g != nil && g.Resource != nil && g.Component != nil { + rootViewName := strings.TrimSpace(g.Component.RootView) + if rootViewName != "" { + if rootView, err := g.Resource.View(rootViewName); err == nil && rootView != nil { + if rootView.Table != "" { + return strings.TrimSpace(rootView.Table) + } + if rootView.Template != nil { + if name := tableNameFromSQL(rootView.Template.Source); name != "" { + return name + } + } + if name := tableNameFromType(rootView.Name); name != "" { + return name + } + } + } + } + if name := tableNameFromType(bodyItemType.Name()); name != "" { + return name + } + return "" +} + +func (g *ComponentCodegen) mutableTableFromViewState(viewParamName string) string { + if g == nil || g.Resource == nil { + return "" + } + for _, input := range g.Component.Input { + if input == nil || strings.TrimSpace(input.Name) != strings.TrimSpace(viewParamName) { + continue + } + viewName := mutableViewNameFromTag(input.Tag) + if viewName == "" { + viewName = strings.TrimSpace(input.Name) + } + if viewName == "" { + continue + } + aView, err := g.Resource.View(viewName) + if err != nil || aView == nil || aView.Template == nil { + continue + } + if table := tableNameFromSQL(aView.Template.Source); table != "" { + return table + } + } + return "" +} + +func mutableViewNameFromTag(tag string) string { + if tag == "" { + return "" + } + idx := strings.Index(tag, `view:"`) + if idx == -1 { + return "" + } + rest := tag[idx+len(`view:"`):] + end := strings.Index(rest, `"`) + if end == -1 { + return "" + } + return strings.TrimSpace(rest[:end]) +} + +func tableNameFromSQL(sql string) string { + fields := strings.Fields(sql) + for i := 0; i < len(fields)-1; i++ { + if strings.EqualFold(fields[i], "FROM") { + candidate := strings.TrimSpace(fields[i+1]) + candidate = strings.Trim(candidate, "`()") + candidate = strings.TrimRight(candidate, ",;") + if candidate != "" { + return candidate + } + } + } + return "" +} + +func tableNameFromType(typeName string) string { + typeName = strings.TrimSpace(typeName) + if typeName == "" { + return "" + } + return text.CaseFormatUpperCamel.Format(typeName, text.CaseFormatUpperUnderscore) +} + +func mutableBodyItemType(rType reflect.Type) (reflect.Type, bool) { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType == nil { + return nil, false + } + switch rType.Kind() { + case reflect.Slice, reflect.Array: + return unwrapNamedStructType(rType.Elem()), true + case reflect.Struct: + return rType, false + default: + return nil, false + } +} + +func mutableRecordName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "Rec" + } + return "Rec" + name +} + +func isMutableRelationField(field reflect.StructField) bool { + return strings.Contains(field.Tag.Get("view"), "table=") || field.Tag.Get("on") != "" +} + +func mutableRelationTableName(field reflect.StructField) string { + viewTag := field.Tag.Get("view") + for _, part := range strings.Split(viewTag, ",") { + part = strings.TrimSpace(part) + if strings.HasPrefix(strings.ToLower(part), "table=") { + return strings.TrimSpace(strings.TrimPrefix(part, "table=")) + } + } + return "" +} + +type mutableRelationAssignment struct { + ParentField string + ChildField string +} + +func mutableRelationAssignments(field reflect.StructField) []mutableRelationAssignment { + raw := strings.TrimSpace(field.Tag.Get("on")) + if raw == "" { + return nil + } + var result []mutableRelationAssignment + for _, expr := range strings.Split(raw, ",") { + expr = strings.TrimSpace(expr) + if expr == "" { + continue + } + parts := strings.Split(expr, "=") + if len(parts) != 2 { + continue + } + parentField := strings.TrimSpace(strings.Split(strings.TrimSpace(parts[0]), ":")[0]) + childField := strings.TrimSpace(strings.Split(strings.TrimSpace(parts[1]), ":")[0]) + if parentField == "" || childField == "" { + continue + } + result = append(result, mutableRelationAssignment{ParentField: parentField, ChildField: childField}) + } + return result +} + +func appendMutableRelationAssignments(block *shapeast.Block, childExpr, parentExpr *shapeast.Ident, assignments []mutableRelationAssignment, parentType, childType reflect.Type) { + for _, assignment := range assignments { + src := shapeast.Expression(shapeast.NewIdent(parentExpr.Name + "." + assignment.ParentField)) + var childFieldType, parentFieldType reflect.Type + if childType != nil { + if childField, ok := childType.FieldByName(assignment.ChildField); ok { + childFieldType = childField.Type + } + } + if parentType != nil { + if parentField, ok := parentType.FieldByName(assignment.ParentField); ok { + parentFieldType = parentField.Type + } + } + if childFieldType != nil && parentFieldType != nil { + childPtr := childFieldType.Kind() == reflect.Ptr + parentPtr := parentFieldType.Kind() == reflect.Ptr + if childPtr && !parentPtr { + src = shapeast.NewRefExpression(src) + } else if !childPtr && parentPtr { + src = shapeast.NewDerefExpression(src) + } + } + block.Append(shapeast.NewAssign(shapeast.NewIdent(childExpr.Name+"."+assignment.ChildField), src)) + } +} + +func (s *mutableComponentSupport) rootHelper() *mutableIndexHelper { + if s == nil { + return nil + } + want := "Cur" + s.BodyFieldName + for i := range s.Helpers { + if s.Helpers[i].ViewFieldName == want { + return &s.Helpers[i] + } + } + if len(s.Helpers) == 1 { + return &s.Helpers[0] + } + return nil +} + +func (s *mutableComponentSupport) findHelper(fieldName string, field reflect.StructField) *mutableIndexHelper { + if s == nil { + return nil + } + itemExpr, _ := collectionItemType(field) + wantSuffix := strings.TrimSpace(fieldName) + for i := range s.Helpers { + helper := &s.Helpers[i] + if itemExpr != "" && strings.EqualFold(strings.TrimSpace(helper.ItemTypeExpr), strings.TrimSpace(itemExpr)) { + return helper + } + if wantSuffix != "" && strings.HasSuffix(strings.TrimSpace(helper.ViewFieldName), wantSuffix) { + return helper + } + } + return nil +} + +func (g *ComponentCodegen) needsMutableSequence(keyType reflect.Type) bool { + if g == nil || g.Component == nil { + return false + } + method := strings.ToUpper(strings.TrimSpace(g.Component.Method)) + if method != "PATCH" && method != "POST" { + return false + } + for keyType != nil && keyType.Kind() == reflect.Ptr { + keyType = keyType.Elem() + } + if keyType == nil { + return false + } + switch keyType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + default: + return false + } +} + +func mutableVeltyOutputUsesBody(aView *view.View) bool { + return aView != nil +} diff --git a/repository/shape/xgen/mutable_helpers.go b/repository/shape/xgen/mutable_helpers.go new file mode 100644 index 000000000..c99e1c021 --- /dev/null +++ b/repository/shape/xgen/mutable_helpers.go @@ -0,0 +1,464 @@ +package xgen + +import ( + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/view/state" +) + +type mutableComponentSupport struct { + BodyFieldName string + BodyTypeName string + BodyMany bool + Helpers []mutableIndexHelper +} + +type mutableIndexHelper struct { + ViewParamName string + ViewFieldName string + TypeName string + MapFieldName string + ItemTypeExpr string + MapTypeExpr string + KeyFieldName string + KeyFieldType string + KeyReadExpr string + NeedNilCheck bool + ItemIsPointer bool + RelationPath string + ItemStruct reflect.Type +} + +func (g *ComponentCodegen) mutableSupport(inputType reflect.Type) *mutableComponentSupport { + if !g.componentUsesVelty() || g.componentUsesHandler() || g.Component == nil || inputType == nil { + return nil + } + bodyFieldName := "" + for _, input := range g.Component.Input { + if input == nil || input.In == nil || input.In.Kind != state.KindRequestBody { + continue + } + bodyFieldName = exportedCodegenParamName(input.Name) + break + } + if bodyFieldName == "" { + return nil + } + + support := &mutableComponentSupport{BodyFieldName: bodyFieldName} + for _, input := range g.Component.Input { + if input == nil || input.In == nil || input.In.Kind != state.KindRequestBody { + continue + } + if input.Schema != nil { + if bodyTypeName := strings.TrimSpace(input.Schema.Name); bodyTypeName != "" { + support.BodyTypeName = bodyTypeName + } + support.BodyMany = input.Schema.Cardinality == state.Many + } + break + } + for _, input := range g.mutableHelperParametersForCodegen() { + if input == nil || input.In == nil || input.In.Kind != state.KindView { + continue + } + helper, ok := g.mutableIndexHelper(inputType, bodyFieldName, input) + if !ok { + continue + } + support.Helpers = append(support.Helpers, helper) + } + if len(support.Helpers) == 0 { + return nil + } + return support +} + +func (g *ComponentCodegen) mutableIndexHelper(inputType reflect.Type, bodyFieldName string, param *state.Parameter) (mutableIndexHelper, bool) { + fieldName := exportedCodegenParamName(param.Name) + if fieldName == "" { + return mutableIndexHelper{}, false + } + viewField, ok := inputType.FieldByName(fieldName) + if !ok { + return mutableIndexHelper{}, false + } + itemTypeExpr, itemStructType := collectionItemType(viewField) + if itemStructType == nil { + return mutableIndexHelper{}, false + } + itemIsPointer := viewField.Type.Kind() == reflect.Slice && viewField.Type.Elem().Kind() == reflect.Ptr + if namedItemTypeExpr := generatedMutableItemTypeExpr(param, itemIsPointer); namedItemTypeExpr != "" { + itemTypeExpr = namedItemTypeExpr + } + keyFieldName := "" + keyType := reflect.Type(nil) + keyReadExpr := "" + needNilCheck := false + if g != nil { + if g.Resource != nil { + if inputView := lookupInputView(g.Resource, strings.TrimSpace(param.Name)); inputView != nil { + if _, resolvedFieldName, resolvedType, ok := g.generatedIndexColumn(g.semanticView(inputView)); ok { + keyFieldName = resolvedFieldName + keyType = resolvedType + keyReadExpr = fmt.Sprintf("item.%s", keyFieldName) + if keyType.Kind() == reflect.Ptr { + needNilCheck = true + keyReadExpr = "*" + keyReadExpr + keyType = keyType.Elem() + } + } + } + } + if keyType == nil { + if resourceType := g.resourceViewStructType(strings.TrimSpace(param.Name)); resourceType != nil { + if keyField, ok := lookupGeneratedIndexField(resourceType); ok { + keyFieldName = keyField.Name + keyType = keyField.Type + keyReadExpr = fmt.Sprintf("item.%s", keyFieldName) + if keyType.Kind() == reflect.Ptr { + needNilCheck = true + keyReadExpr = "*" + keyReadExpr + keyType = keyType.Elem() + } + } + } + } + } + if keyType == nil { + keyField, ok := lookupGeneratedIndexField(itemStructType) + if !ok { + return mutableIndexHelper{}, false + } + keyFieldName = keyField.Name + keyType = keyField.Type + keyReadExpr = fmt.Sprintf("item.%s", keyFieldName) + if keyType.Kind() == reflect.Ptr { + needNilCheck = true + keyReadExpr = "*" + keyReadExpr + keyType = keyType.Elem() + } + } + keyTypeExpr := sourceTypeExpr(keyType, "") + if keyTypeExpr == "" { + return mutableIndexHelper{}, false + } + mapFieldName := fieldName + "By" + keyFieldName + if _, exists := inputType.FieldByName(mapFieldName); exists { + return mutableIndexHelper{}, false + } + return mutableIndexHelper{ + ViewParamName: strings.TrimSpace(param.Name), + ViewFieldName: fieldName, + TypeName: func() string { + if param.Schema == nil { + return "" + } + return strings.TrimSpace(param.Schema.Name) + }(), + MapFieldName: mapFieldName, + ItemTypeExpr: itemTypeExpr, + MapTypeExpr: fmt.Sprintf("map[%s]%s", keyTypeExpr, itemTypeExpr), + KeyFieldName: keyFieldName, + KeyFieldType: keyTypeExpr, + KeyReadExpr: keyReadExpr, + NeedNilCheck: needNilCheck, + ItemIsPointer: itemIsPointer, + RelationPath: mutableRelationPath(inputType, itemStructType, bodyFieldName), + ItemStruct: itemStructType, + }, true +} + +func generatedMutableItemTypeExpr(param *state.Parameter, itemIsPointer bool) string { + if param == nil || param.Schema == nil { + return "" + } + typeName := strings.TrimSpace(param.Schema.Name) + if typeName == "" { + return "" + } + if itemIsPointer { + return "*" + typeName + } + return typeName +} + +func (g *ComponentCodegen) mutableHelperParametersForCodegen() []*state.Parameter { + params := g.codegenInputParameters() + if len(params) == 0 { + return nil + } + result := make([]*state.Parameter, 0, len(params)) + for _, item := range params { + if item == nil { + continue + } + result = append(result, item) + } + return result +} + +func mutableRelationPath(inputType reflect.Type, itemType reflect.Type, bodyFieldName string) string { + if inputType == nil || itemType == nil || bodyFieldName == "" { + return "" + } + bodyField, ok := inputType.FieldByName(bodyFieldName) + if !ok { + return "" + } + rootType, _ := mutableBodyItemType(bodyField.Type) + if rootType == nil { + return "" + } + if sameNamedStructType(rootType, itemType) { + return "" + } + return lookupMutableRelationPath(rootType, itemType, "") +} + +func lookupMutableRelationPath(parentType reflect.Type, itemType reflect.Type, prefix string) string { + parentType = unwrapNamedStructType(parentType) + itemType = unwrapNamedStructType(itemType) + if parentType == nil || itemType == nil { + return "" + } + for i := 0; i < parentType.NumField(); i++ { + field := parentType.Field(i) + if !isMutableRelationField(field) { + continue + } + childType, _ := mutableBodyItemType(field.Type) + if childType == nil { + continue + } + current := field.Name + if prefix != "" { + current = prefix + "/" + current + } + if sameNamedStructType(childType, itemType) { + return current + } + if nested := lookupMutableRelationPath(childType, itemType, current); nested != "" { + return nested + } + } + return "" +} + +func sameNamedStructType(left, right reflect.Type) bool { + left = unwrapNamedStructType(left) + right = unwrapNamedStructType(right) + if left == nil || right == nil { + return false + } + if left == right { + return true + } + if left.Name() != "" && right.Name() != "" && left.Name() == right.Name() && left.PkgPath() == right.PkgPath() { + return true + } + return false +} + +func (s *mutableComponentSupport) renderInputFields(builder *strings.Builder) { + if s == nil { + return + } + for _, helper := range s.Helpers { + builder.WriteString(fmt.Sprintf("\t%s %s `json:\"-\"`\n", helper.MapFieldName, helper.MapTypeExpr)) + } +} + +func (s *mutableComponentSupport) renderInputInit(inputTypeName, outputTypeName string) string { + if s == nil { + return "" + } + if strings.TrimSpace(inputTypeName) == "" { + inputTypeName = "Input" + } + if strings.TrimSpace(outputTypeName) == "" { + outputTypeName = "Output" + } + var builder strings.Builder + builder.WriteString(fmt.Sprintf("func (i *%s) Init(ctx context.Context, sess handler.Session, output *%s) error {\n", inputTypeName, outputTypeName)) + builder.WriteString("\tif err := sess.Stater().Bind(ctx, i); err != nil {\n") + builder.WriteString("\t\treturn err\n") + builder.WriteString("\t}\n") + builder.WriteString("\ti.indexSlice()\n") + builder.WriteString("\treturn nil\n") + builder.WriteString("}\n\n") + builder.WriteString(fmt.Sprintf("func (i *%s) indexSlice() {\n", inputTypeName)) + for _, helper := range s.Helpers { + builder.WriteString(fmt.Sprintf("\ti.%s = make(%s, len(i.%s))\n", helper.MapFieldName, helper.MapTypeExpr, helper.ViewFieldName)) + builder.WriteString(fmt.Sprintf("\tfor _, item := range i.%s {\n", helper.ViewFieldName)) + if helper.ItemIsPointer { + builder.WriteString("\t\tif item == nil {\n") + builder.WriteString("\t\t\tcontinue\n") + builder.WriteString("\t\t}\n") + } + if helper.NeedNilCheck { + builder.WriteString(fmt.Sprintf("\t\tif item.%s == nil {\n", helper.KeyFieldName)) + builder.WriteString("\t\t\tcontinue\n") + builder.WriteString("\t\t}\n") + } + builder.WriteString(fmt.Sprintf("\t\ti.%s[%s] = item\n", helper.MapFieldName, helper.KeyReadExpr)) + builder.WriteString("\t}\n") + } + builder.WriteString("}\n") + return builder.String() +} + +func (s *mutableComponentSupport) renderInputValidate(inputTypeName, outputTypeName string) string { + if s == nil { + return "" + } + if strings.TrimSpace(inputTypeName) == "" { + inputTypeName = "Input" + } + if strings.TrimSpace(outputTypeName) == "" { + outputTypeName = "Output" + } + var builder strings.Builder + builder.WriteString(fmt.Sprintf("func (i *%s) Validate(ctx context.Context, sess handler.Session, output *%s) error {\n", inputTypeName, outputTypeName)) + builder.WriteString("\taValidator := sess.Validator()\n") + builder.WriteString("\tsessionDb, err := sess.Db()\n") + builder.WriteString("\tif err != nil {\n") + builder.WriteString("\t\treturn err\n") + builder.WriteString("\t}\n") + builder.WriteString("\tdb, err := sessionDb.Db(ctx)\n") + builder.WriteString("\tif err != nil {\n") + builder.WriteString("\t\treturn err\n") + builder.WriteString("\t}\n") + builder.WriteString("\tvar options = []validator.Option{\n") + builder.WriteString(fmt.Sprintf("\t\tvalidator.WithLocation(%q),\n", s.BodyFieldName)) + builder.WriteString("\t\tvalidator.WithDB(db),\n") + builder.WriteString("\t\tvalidator.WithUnique(true),\n") + builder.WriteString("\t\tvalidator.WithRefCheck(true),\n") + builder.WriteString("\t\tvalidator.WithCanUseMarkerProvider(i.canUseMarkerProvider),\n") + builder.WriteString("\t}\n") + builder.WriteString("\tvalidation := validator.NewValidation()\n") + builder.WriteString(fmt.Sprintf("\terr = i.validate(ctx, aValidator, validation, options, i.%s)\n", s.BodyFieldName)) + builder.WriteString("\toutput.Violations = append(output.Violations, validation.Violations...)\n") + builder.WriteString("\tif err == nil && len(validation.Violations) > 0 {\n") + builder.WriteString("\t\tvalidation.Violations.Sort()\n") + builder.WriteString("\t}\n") + builder.WriteString("\treturn err\n") + builder.WriteString("}\n\n") + builder.WriteString(fmt.Sprintf("func (i *%s) validate(ctx context.Context, aValidator *validator.Service, validation *validator.Validation, options []validator.Option, value interface{}) error {\n", inputTypeName)) + builder.WriteString("\t_, err := aValidator.Validate(ctx, value, append(options, validator.WithValidation(validation))...)\n") + builder.WriteString("\tif err != nil {\n") + builder.WriteString("\t\treturn err\n") + builder.WriteString("\t}\n") + builder.WriteString("\treturn nil\n") + builder.WriteString("}\n\n") + builder.WriteString(fmt.Sprintf("func (i *%s) canUseMarkerProvider(v interface{}) bool {\n", inputTypeName)) + builder.WriteString("\tswitch actual := v.(type) {\n") + for _, helper := range s.Helpers { + builder.WriteString(fmt.Sprintf("\tcase %s:\n", helper.ItemTypeExpr)) + if helper.NeedNilCheck { + builder.WriteString(fmt.Sprintf("\t\tif actual.%s == nil {\n", helper.KeyFieldName)) + builder.WriteString("\t\t\treturn false\n") + builder.WriteString("\t\t}\n") + } + actualKey := fmt.Sprintf("actual.%s", helper.KeyFieldName) + if helper.NeedNilCheck { + actualKey = "*" + actualKey + } + builder.WriteString(fmt.Sprintf("\t\t_, ok := i.%s[%s]\n", helper.MapFieldName, actualKey)) + builder.WriteString("\t\treturn ok\n") + } + builder.WriteString("\tdefault:\n") + builder.WriteString("\t\treturn true\n") + builder.WriteString("\t}\n") + builder.WriteString("}\n") + return builder.String() +} + +func collectionItemType(field reflect.StructField) (string, reflect.Type) { + rType := field.Type + expr := sourceFieldTypeExpr(field) + if expr == "" || rType == nil { + return "", nil + } + switch rType.Kind() { + case reflect.Slice, reflect.Array: + return strings.TrimPrefix(expr, "[]"), unwrapNamedStructType(rType.Elem()) + default: + return "", nil + } +} + +func unwrapNamedStructType(rType reflect.Type) reflect.Type { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType == nil || rType.Kind() != reflect.Struct { + return nil + } + return rType +} + +func lookupGeneratedIndexField(structType reflect.Type) (reflect.StructField, bool) { + if structType == nil || structType.Kind() != reflect.Struct { + return reflect.StructField{}, false + } + if field, ok := structType.FieldByName("Id"); ok { + return field, true + } + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if generatedSQLXFieldName(field.Tag.Get("sqlx")) == "ID" { + return field, true + } + } + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if strings.Contains(strings.ToLower(field.Tag.Get("sqlx")), "primarykey") { + return field, true + } + } + return reflect.StructField{}, false +} + +func sourceTypeExpr(rType reflect.Type, typeName string) string { + if rType == nil { + return typeName + } + switch rType.Kind() { + case reflect.Ptr: + return "*" + sourceTypeExpr(rType.Elem(), typeName) + case reflect.Slice: + return "[]" + sourceTypeExpr(rType.Elem(), typeName) + case reflect.Array: + return fmt.Sprintf("[%d]%s", rType.Len(), sourceTypeExpr(rType.Elem(), typeName)) + case reflect.Map: + return "map[" + sourceTypeExpr(rType.Key(), "") + "]" + sourceTypeExpr(rType.Elem(), typeName) + default: + if typeName != "" { + return typeName + } + return rType.String() + } +} + +func generatedSQLXFieldName(tag string) string { + tag = strings.TrimSpace(tag) + if tag == "" { + return "" + } + for _, part := range strings.Split(tag, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if strings.HasPrefix(part, "name=") { + return strings.TrimSpace(strings.TrimPrefix(part, "name=")) + } + if !strings.Contains(part, "=") { + return part + } + } + return "" +} diff --git a/repository/shape/xgen/repro_xgen_shapefragment_test.go b/repository/shape/xgen/repro_xgen_shapefragment_test.go new file mode 100644 index 000000000..5837a50e3 --- /dev/null +++ b/repository/shape/xgen/repro_xgen_shapefragment_test.go @@ -0,0 +1,35 @@ +package xgen + +import ( + shapeload "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" + "path/filepath" + "reflect" + "testing" +) + +func TestReproShapeFragment(t *testing.T) { + projectDir := t.TempDir() + packageDir := filepath.Join(projectDir, "shape", "dev", "vendor", "details") + component := &shapeload.Component{Method: "GET", URI: "/v1/api/shape/dev/vendors/{vendorID}", RootView: "vendor", Output: []*plan.State{{Parameter: state.Parameter{Name: "Data", In: state.NewOutputLocation("view"), Schema: &state.Schema{Cardinality: state.Many}}}}} + resource := view.EmptyResource() + resource.Views = append(resource.Views, &view.View{Name: "vendor", Schema: &state.Schema{Name: "VendorView", DataType: "*VendorView", Cardinality: state.Many}}) + resource.Views[0].Schema.SetType(reflect.TypeOf([]struct { + ID int + Products []*struct{ ID int } `view:",table=PRODUCT" json:",omitempty" sqlx:"-"` + }{})) + ctx := &typectx.Context{PackageDir: packageDir, PackageName: "details", PackagePath: "github.com/acme/project/shape/dev/vendor/details"} + codegen := &ComponentCodegen{Component: component, Resource: resource, TypeContext: ctx, ProjectDir: projectDir, WithEmbed: false, WithContract: false} + frag, err := codegen.generateShapeFragment(projectDir, packageDir, "details", ctx.PackagePath) + if err != nil { + t.Fatalf("generateShapeFragment err: %v", err) + } + if frag == nil { + t.Fatalf("nil fragment") + } + t.Logf("types=%v", frag.Types) + t.Logf("decls=%s", frag.TypeDecls) +} diff --git a/repository/shape/xgen/resource.go b/repository/shape/xgen/resource.go new file mode 100644 index 000000000..760d17a86 --- /dev/null +++ b/repository/shape/xgen/resource.go @@ -0,0 +1,95 @@ +package xgen + +import ( + "fmt" + "strings" + + "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/view" +) + +// GenerateFromResource produces Go structs directly from an in-memory view.Resource +// without YAML roundtrip. Uses real columns from DB discovery when available. +func GenerateFromResource(resource *view.Resource, typeCtx *typectx.Context, cfg *Config) (*Result, error) { + if resource == nil { + return nil, fmt.Errorf("shape xgen: nil resource") + } + doc := resourceToShapeDocument(resource, typeCtx) + return GenerateFromDQLShape(doc, cfg) +} + +// resourceToShapeDocument converts an in-memory view.Resource into a shape.Document +// that xgen can process. This avoids the YAML marshal/unmarshal roundtrip. +func resourceToShapeDocument(resource *view.Resource, typeCtx *typectx.Context) *shape.Document { + root := map[string]any{} + + // Build Resource.Views from in-memory views + var views []any + for _, aView := range resource.Views { + if aView == nil { + continue + } + viewMap := map[string]any{ + "Name": aView.Name, + "Table": aView.Table, + "Mode": string(aView.Mode), + } + if aView.Module != "" { + viewMap["Module"] = aView.Module + } + // Schema + if aView.Schema != nil { + schema := map[string]any{} + if aView.Schema.Name != "" { + schema["Name"] = aView.Schema.Name + } + if aView.Schema.DataType != "" { + schema["DataType"] = aView.Schema.DataType + } + if aView.Schema.Cardinality != "" { + schema["Cardinality"] = string(aView.Schema.Cardinality) + } + viewMap["Schema"] = schema + } + // Columns — this is the key: real columns from DB discovery + if len(aView.Columns) > 0 { + var columns []any + for _, col := range aView.Columns { + if col == nil { + continue + } + colMap := map[string]any{ + "Name": col.Name, + "DataType": col.DataType, + } + if col.Tag != "" { + colMap["Tag"] = col.Tag + } + if col.Nullable { + colMap["Nullable"] = true + } + columns = append(columns, colMap) + } + viewMap["Columns"] = columns + } + views = append(views, viewMap) + } + root["Resource"] = map[string]any{"Views": views} + return &shape.Document{ + Root: root, + TypeContext: typeCtx, + } +} + +// columnDataType returns Go type name for a view.Column. +func columnDataType(col *view.Column) string { + if col.DataType != "" { + return col.DataType + } + rType := col.ColumnType() + if rType == nil { + return "string" + } + return strings.TrimPrefix(rType.String(), "*") +} diff --git a/service.go b/service.go index 3f7e9b28b..2c984ddf3 100644 --- a/service.go +++ b/service.go @@ -587,7 +587,14 @@ func (s *Service) AddComponent(ctx context.Context, component *repository.Compon return err } - s.repository.Register(components.Components...) + registerComponents := append([]*repository.Component{}, components.Components...) + if reportComponent, err := repository.BuildReportComponent(s.repository.Registry().Dispatcher(), components.Components[0]); err != nil { + return err + } else if reportComponent != nil { + registerComponents = append(registerComponents, reportComponent) + } + + s.repository.Register(registerComponents...) return nil } diff --git a/service/executor/expand/data_unit.go b/service/executor/expand/data_unit.go index 9d2c22dad..7f0358b2e 100644 --- a/service/executor/expand/data_unit.go +++ b/service/executor/expand/data_unit.go @@ -3,7 +3,9 @@ package expand import ( "context" "fmt" + "os" "reflect" + "runtime/debug" "strings" "sync" @@ -64,6 +66,9 @@ func (c *DataUnit) Validate(dest interface{}, opts ...interface{}) (*validator.V } func (c *DataUnit) Allocate(tableName string, dest interface{}, selector string) (string, error) { + if os.Getenv("DATLY_DEBUG_MUTABLE") == "1" { + fmt.Printf("[MUTABLE DEBUG] Allocate table=%s selector=%s destType=%T dest=%#v\n", tableName, selector, dest, dest) + } db, err := c.MetaSource.Db() if err != nil { fmt.Printf("error occured while connecting to DB %v\n", err.Error()) @@ -208,6 +213,15 @@ func (c *DataUnit) FilterExecutables(statements []string, stopOnNonExec bool) [] } func (c *DataUnit) In(columnName string, args interface{}) (string, error) { + if os.Getenv("DATLY_DEBUG_DATAUNIT") == "1" { + defer func() { + if r := recover(); r != nil { + fmt.Printf("[DATAUNIT DEBUG] In panic column=%q argsType=%T args=%#v err=%v\n%s\n", columnName, args, args, r, debug.Stack()) + panic(r) + } + }() + fmt.Printf("[DATAUNIT DEBUG] In column=%q argsType=%T args=%#v\n", columnName, args, args) + } return c.in(columnName, args, true) } diff --git a/service/executor/expand/parent.go b/service/executor/expand/parent.go index 11ae1162f..97787429d 100644 --- a/service/executor/expand/parent.go +++ b/service/executor/expand/parent.go @@ -2,8 +2,10 @@ package expand import ( "database/sql" + "fmt" "github.com/viant/datly/utils/types" "github.com/viant/xunsafe" + "os" "reflect" "strings" ) @@ -247,10 +249,16 @@ func NotZeroOf(values ...int) int { } func (c *DataUnit) Insert(data interface{}, tableName string) (string, error) { + if os.Getenv("DATLY_DEBUG_MUTABLE") == "1" { + fmt.Printf("[MUTABLE DEBUG] Insert table=%s dataType=%T data=%#v\n", tableName, data, data) + } return c.Statements.InsertWithMarker(tableName, data), nil } func (c *DataUnit) Update(data interface{}, tableName string) (string, error) { + if os.Getenv("DATLY_DEBUG_MUTABLE") == "1" { + fmt.Printf("[MUTABLE DEBUG] Update table=%s dataType=%T data=%#v\n", tableName, data, data) + } return c.Statements.UpdateWithMarker(tableName, data), nil } diff --git a/service/operator/reader.go b/service/operator/reader.go index 801638e12..5a29d9eeb 100644 --- a/service/operator/reader.go +++ b/service/operator/reader.go @@ -3,6 +3,7 @@ package operator import ( "context" "net/http" + "os" "github.com/viant/datly/repository" "github.com/viant/datly/service/reader" @@ -23,6 +24,9 @@ func (s *Service) runQuery(ctx context.Context, component *repository.Component, defer func() { if r := recover(); r != nil { panicMsg := fmt.Sprintf("Panic occurred: %v, Stack trace: %v", r, string(debug.Stack())) + if os.Getenv("DATLY_DEBUG_OPERATOR") == "1" { + fmt.Printf("[OPERATOR DEBUG] %s\n", panicMsg) + } logger := aSession.Logger() if logger == nil { panic(panicMsg) diff --git a/service/reader/anonymous_mysql_test.go b/service/reader/anonymous_mysql_test.go new file mode 100644 index 000000000..3a7c761eb --- /dev/null +++ b/service/reader/anonymous_mysql_test.go @@ -0,0 +1,171 @@ +package reader + +import ( + "context" + "database/sql" + "os" + "reflect" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + vstate "github.com/viant/datly/view/state" + sqlxread "github.com/viant/sqlx/io/read" +) + +func TestSQLXReader_AnonymousVsNamedPatchType(t *testing.T) { + if os.Getenv("TEST") != "1" { + t.Skip("set TEST=1 to run integration reader check") + } + + db, err := sql.Open("mysql", "root:dev@tcp(localhost:3306)/dev?parseTime=true") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + t.Run("anonymous", func(t *testing.T) { + type anonymousHas struct { + Id bool + Name bool + Quantity bool + } + type anonymousRow struct { + Id int + Name *string + Quantity *int + Has *anonymousHas + } + + reader, err := sqlxread.New(context.Background(), db, "SELECT * FROM FOOS WHERE ID = 4", func() interface{} { + return &anonymousRow{} + }) + require.NoError(t, err) + + var rows []*anonymousRow + err = reader.QueryAll(context.Background(), func(row interface{}) error { + rows = append(rows, row.(*anonymousRow)) + return nil + }) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, 4, rows[0].Id) + }) + + t.Run("named-reflect-structof", func(t *testing.T) { + hasType := reflect.StructOf([]reflect.StructField{ + {Name: "Id", Type: reflect.TypeOf(true)}, + {Name: "Name", Type: reflect.TypeOf(true)}, + {Name: "Quantity", Type: reflect.TypeOf(true)}, + }) + rowType := reflect.StructOf([]reflect.StructField{ + {Name: "Id", Type: reflect.TypeOf(int(0)), Tag: `sqlx:"ID"`}, + {Name: "Name", Type: reflect.TypeOf((*string)(nil)), Tag: `sqlx:"NAME"`}, + {Name: "Quantity", Type: reflect.TypeOf((*int)(nil)), Tag: `sqlx:"QUANTITY"`}, + {Name: "Has", Type: reflect.PtrTo(hasType), Tag: `setMarker:"true" format:"-" sqlx:"-" diff:"-" json:"-"`}, + }) + + reader, err := sqlxread.New(context.Background(), db, "SELECT * FROM FOOS WHERE ID = 4", func() interface{} { + return reflect.New(rowType).Interface() + }) + require.NoError(t, err) + + var rows []interface{} + err = reader.QueryAll(context.Background(), func(row interface{}) error { + rows = append(rows, row) + return nil + }) + require.NoError(t, err) + require.Len(t, rows, 1) + }) + + t.Run("collector-backed-anonymous", func(t *testing.T) { + type anonymousHas struct { + Id bool + Name bool + Quantity bool + } + type anonymousRow struct { + Id int `sqlx:"ID"` + Name *string `sqlx:"NAME"` + Quantity *int `sqlx:"QUANTITY"` + Has *anonymousHas `setMarker:"true" format:"-" sqlx:"-" diff:"-" json:"-"` + } + + aView := &view.View{ + Name: "CurFoos", + Schema: vstate.NewSchema(reflect.TypeOf([]*anonymousRow{})), + } + aView.Schema.Cardinality = vstate.Many + collector := view.NewCollector(aView.Schema.Slice(), aView, &[]*anonymousRow{}, nil, false) + reader, err := sqlxread.New(context.Background(), db, "SELECT * FROM FOOS WHERE ID = 4", collector.NewItem()) + require.NoError(t, err) + + err = reader.QueryAll(context.Background(), collector.Visitor(context.Background())) + require.NoError(t, err) + dest := collector.Dest().([]*anonymousRow) + require.Len(t, dest, 1) + require.Equal(t, 4, dest[0].Id) + }) + + t.Run("collector-backed-reflect-structof", func(t *testing.T) { + hasType := reflect.StructOf([]reflect.StructField{ + {Name: "Id", Type: reflect.TypeOf(true)}, + {Name: "Name", Type: reflect.TypeOf(true)}, + {Name: "Quantity", Type: reflect.TypeOf(true)}, + }) + rowType := reflect.StructOf([]reflect.StructField{ + {Name: "Id", Type: reflect.TypeOf(int(0)), Tag: `sqlx:"ID"`}, + {Name: "Name", Type: reflect.TypeOf((*string)(nil)), Tag: `sqlx:"NAME"`}, + {Name: "Quantity", Type: reflect.TypeOf((*int)(nil)), Tag: `sqlx:"QUANTITY"`}, + {Name: "Has", Type: reflect.PtrTo(hasType), Tag: `setMarker:"true" format:"-" sqlx:"-" diff:"-" json:"-"`}, + }) + sliceType := reflect.SliceOf(reflect.PtrTo(rowType)) + aView := &view.View{ + Name: "CurFoos", + Schema: vstate.NewSchema(sliceType), + } + aView.Schema.Cardinality = vstate.Many + + destPtr := reflect.New(sliceType).Interface() + collector := view.NewCollector(aView.Schema.Slice(), aView, destPtr, nil, false) + reader, err := sqlxread.New(context.Background(), db, "SELECT * FROM FOOS WHERE ID = 4", collector.NewItem()) + require.NoError(t, err) + + err = reader.QueryAll(context.Background(), collector.Visitor(context.Background())) + require.NoError(t, err) + destValue := reflect.ValueOf(collector.Dest()) + require.Equal(t, 1, destValue.Len()) + require.Equal(t, int64(4), destValue.Index(0).Elem().FieldByName("Id").Int()) + }) + + t.Run("collector-backed-reflect-structof-v1-order", func(t *testing.T) { + hasType := reflect.StructOf([]reflect.StructField{ + {Name: "Name", Type: reflect.TypeOf(true)}, + {Name: "Quantity", Type: reflect.TypeOf(true)}, + {Name: "Id", Type: reflect.TypeOf(true)}, + }) + rowType := reflect.StructOf([]reflect.StructField{ + {Name: "Name", Type: reflect.TypeOf((*string)(nil)), Tag: `sqlx:"NAME"`}, + {Name: "Quantity", Type: reflect.TypeOf((*int)(nil)), Tag: `sqlx:"QUANTITY"`}, + {Name: "Id", Type: reflect.TypeOf(int(0)), Tag: `sqlx:"ID"`}, + {Name: "Has", Type: reflect.PtrTo(hasType), Tag: `setMarker:"true" format:"-" sqlx:"-" diff:"-" json:"-"`}, + }) + sliceType := reflect.SliceOf(reflect.PtrTo(rowType)) + aView := &view.View{ + Name: "CurFoos", + Schema: vstate.NewSchema(sliceType), + } + aView.Schema.Cardinality = vstate.Many + + destPtr := reflect.New(sliceType).Interface() + collector := view.NewCollector(aView.Schema.Slice(), aView, destPtr, nil, false) + reader, err := sqlxread.New(context.Background(), db, "SELECT * FROM FOOS WHERE ID = 4", collector.NewItem()) + require.NoError(t, err) + + err = reader.QueryAll(context.Background(), collector.Visitor(context.Background())) + require.NoError(t, err) + destValue := reflect.ValueOf(collector.Dest()) + require.Equal(t, 1, destValue.Len()) + require.Equal(t, int64(4), destValue.Index(0).Elem().FieldByName("Id").Int()) + }) +} diff --git a/service/reader/handler/handler.go b/service/reader/handler/handler.go index 6e83335e3..98a3e47d3 100644 --- a/service/reader/handler/handler.go +++ b/service/reader/handler/handler.go @@ -3,6 +3,7 @@ package handler import ( "context" "encoding/json" + "fmt" "github.com/viant/datly/gateway/router/status" _ "github.com/viant/datly/repository/locator/async" @@ -12,7 +13,9 @@ import ( _ "github.com/viant/datly/service/executor/handler/locator" "net/http" + "os" "reflect" + "runtime/debug" reader "github.com/viant/datly/service/reader" "github.com/viant/datly/service/session" @@ -96,6 +99,14 @@ func (h *Handler) Handle(ctx context.Context, aView *view.View, aSession *sessio } func (h *Handler) readData(ctx context.Context, aView *view.View, aState *session.Session, ret *Response, opts []reader.Option) error { + if os.Getenv("DATLY_DEBUG_HANDLER_READ") == "1" { + defer func() { + if r := recover(); r != nil { + fmt.Printf("[HANDLER READ DEBUG] panic view=%s err=%v\n%s\n", aView.Name, r, debug.Stack()) + panic(r) + } + }() + } destValue := reflect.New(aView.Schema.SliceType()) dest := destValue.Interface() aSession, err := reader.NewSession(dest, aView) diff --git a/service/reader/service.go b/service/reader/service.go index 1fda9fc9e..d362ec899 100644 --- a/service/reader/service.go +++ b/service/reader/service.go @@ -4,7 +4,9 @@ import ( "context" "database/sql" "fmt" + "os" "reflect" + "runtime/debug" "strings" "sync" "sync/atomic" @@ -35,6 +37,14 @@ type Service struct { // ReadInto reads Data into provided destination, * dDest` is required. It has to be a pointer to `interface{}` or pointer to slice of `T` or `*T` func (s *Service) ReadInto(ctx context.Context, dest interface{}, aView *view.View, opts ...Option) error { + if os.Getenv("DATLY_DEBUG_READER") == "1" { + defer func() { + if r := recover(); r != nil { + fmt.Printf("[READER DEBUG] panic view=%s dest=%T err=%v\n%s\n", aView.Name, dest, r, debug.Stack()) + panic(r) + } + }() + } session, err := NewSession(dest, aView, opts...) if err != nil { return err @@ -408,6 +418,9 @@ func (s *Service) BuildCriteria(ctx context.Context, value interface{}, options } func (s *Service) queryInBatches(ctx context.Context, session *Session, aView *view.View, collector *view.Collector, visitor view.VisitorFn, info *response.SQLExecutions, batchData *view.BatchData, selector *view.Statelet) error { + if os.Getenv("DATLY_DEBUG_QUERY_HANDLER") == "1" { + fmt.Printf("[QUERY DEBUG] queryInBatches view=%s selectorTemplateNil=%v batchValues=%d\n", aView.Name, selector == nil || selector.Template == nil, len(batchData.ValuesBatch)) + } wg := &sync.WaitGroup{} db, err := aView.Db() if err != nil { @@ -446,10 +459,19 @@ func (s *Service) queryObjects(ctx context.Context, session *Session, aView *vie return s.queryWithPartitions(ctx, session, aView, selector, batchData, db, collector, visitor, partitioned) } readData := 0 + if os.Getenv("DATLY_DEBUG_QUERY_HANDLER") == "1" { + fmt.Printf("[QUERY DEBUG] queryObjects view=%s schema=%v slice=%v collectorView=%s\n", aView.Name, aView.Schema.Type(), aView.Schema.SliceType(), collector.View().Name) + } parametrizedSQL, columnInMatcher, err := s.buildParametrizedSQL(ctx, aView, selector, batchData, collector, session, nil) if err != nil { + if os.Getenv("DATLY_DEBUG_QUERY_HANDLER") == "1" { + fmt.Printf("[QUERY DEBUG] buildParametrizedSQL error view=%s err=%v\n", aView.Name, err) + } return nil, err } + if os.Getenv("DATLY_DEBUG_QUERY_HANDLER") == "1" { + fmt.Printf("[QUERY DEBUG] builtSQL view=%s sql=%s args=%#v\n", aView.Name, parametrizedSQL.SQL, parametrizedSQL.Args) + } var parentProvider func(value interface{}) (interface{}, error) handler := func(row interface{}) error { @@ -514,6 +536,9 @@ func (s *Service) queryWithHandler(ctx context.Context, session *Session, aView stats, onDone := NewExecutionInfo(parametrizedSQL, cacheStats, collector) defer onDone() + if os.Getenv("DATLY_DEBUG_QUERY_HANDLER") == "1" { + fmt.Printf("[QUERY HANDLER] view=%s sql=%s args=%#v\n", aView.Name, parametrizedSQL.SQL, parametrizedSQL.Args) + } if session.DryRun { return []*response.SQLExecution{stats}, nil } @@ -543,7 +568,16 @@ BEGIN: } _ = stmt.Close() }() - err = reader.QueryAll(ctx, handler, parametrizedSQL.Args...) + debugHandler := handler + if os.Getenv("DATLY_DEBUG_QUERY_HANDLER") == "1" { + debugHandler = func(row interface{}) error { + fmt.Printf("[QUERY HANDLER] view=%s before unwrap row=%T readData=%d\n", aView.Name, row, *readData) + err := handler(row) + fmt.Printf("[QUERY HANDLER] view=%s after handler row=%T readData=%d err=%v\n", aView.Name, row, *readData, err) + return err + } + } + err = reader.QueryAll(ctx, debugHandler, parametrizedSQL.Args...) isInvalidConnection = err != nil && strings.Contains(err.Error(), "invalid connection") if isInvalidConnection && atomic.AddUint32(&retires, 1) < 3 { diff --git a/service/reader/sql.go b/service/reader/sql.go index 256cd26f3..53af23f7e 100644 --- a/service/reader/sql.go +++ b/service/reader/sql.go @@ -3,14 +3,19 @@ package reader import ( "context" "fmt" + "os" "strconv" "strings" + "github.com/viant/datly/internal/inference" "github.com/viant/datly/service/executor/expand" "github.com/viant/datly/service/reader/metadata" "github.com/viant/datly/shared" "github.com/viant/datly/view" "github.com/viant/datly/view/keywords" + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/query" "github.com/viant/sqlx/io/read/cache" ) @@ -76,15 +81,21 @@ func (b *Builder) Build(ctx context.Context, opts ...BuilderOption) (*cache.Parm if len(state.Filters) > 0 { statelet.AppendFilters(state.Filters) } - if aView.Template.IsActualTemplate() && aView.ShouldTryDiscover() { - state.Expanded = metadata.EnrichWithDiscover(state.Expanded, true) - } sb := strings.Builder{} sb.WriteString(selectFragment) - if err = b.appendColumns(&sb, aView, statelet); err != nil { + projectedColumns, err := b.appendColumns(&sb, aView, statelet) + if err != nil { return nil, err } + if aView.Groupable { + if state.Expanded, err = b.rewriteGroupBy(state.Expanded, aView.Columns, projectedColumns); err != nil { + return nil, err + } + } + if aView.Template.IsActualTemplate() && aView.ShouldTryDiscover() { + state.Expanded = metadata.EnrichWithDiscover(state.Expanded, true) + } if err = b.appendRelationColumn(&sb, aView, statelet, relation); err != nil { return nil, err @@ -159,6 +170,9 @@ func (b *Builder) Build(ctx context.Context, opts ...BuilderOption) (*cache.Parm SQL: SQL, Args: placeholders, } + if os.Getenv("DATLY_DEBUG_SQL_BUILDER") == "1" { + fmt.Printf("[SQL BUILDER] view=%s sql=%s args=%#v state=%s\n", aView.Name, SQL, placeholders, state.Expanded) + } if exclude.ColumnsIn && relation != nil { parametrizedQuery.By = shared.FirstNotEmpty(relation.Of.On[0].Field, relation.Of.On[0].Column) @@ -173,20 +187,21 @@ func (b *Builder) Build(ctx context.Context, opts ...BuilderOption) (*cache.Parm return parametrizedQuery, err } -func (b *Builder) appendColumns(sb *strings.Builder, aView *view.View, selector *view.Statelet) error { +func (b *Builder) appendColumns(sb *strings.Builder, aView *view.View, selector *view.Statelet) ([]*view.Column, error) { if len(selector.Columns) == 0 { b.appendViewColumns(sb, aView) - return nil + return nil, nil } return b.appendSelectorColumns(sb, aView, selector) } -func (b *Builder) appendSelectorColumns(sb *strings.Builder, view *view.View, selector *view.Statelet) error { +func (b *Builder) appendSelectorColumns(sb *strings.Builder, aView *view.View, selector *view.Statelet) ([]*view.Column, error) { + result := make([]*view.Column, 0, len(selector.Columns)) for i, column := range selector.Columns { - viewColumn, ok := view.ColumnByName(column) + viewColumn, ok := aView.ColumnByName(column) if !ok { - return fmt.Errorf("not found column %v at view %v", column, view.Name) + return nil, fmt.Errorf("not found column %v at view %v", column, aView.Name) } if i != 0 { @@ -194,10 +209,33 @@ func (b *Builder) appendSelectorColumns(sb *strings.Builder, view *view.View, se } sb.WriteString(" ") - sb.WriteString(viewColumn.SqlExpression()) + if aView.Groupable { + sb.WriteString(groupedProjectionExpression(viewColumn)) + } else { + sb.WriteString(viewColumn.SqlExpression()) + } + result = append(result, viewColumn) } - return nil + return result, nil +} + +func groupedProjectionExpression(column *view.Column) string { + if column == nil { + return "" + } + expr := column.Name + if defaultValue := columnDefaultValue(column); defaultValue != "" { + return "COALESCE(" + expr + "," + defaultValue + ") AS " + column.Name + } + return expr +} + +func columnDefaultValue(column *view.Column) string { + if column == nil { + return "" + } + return column.DefaultValue() } func (b *Builder) viewAlias(view *view.View) string { @@ -225,6 +263,149 @@ func (b *Builder) appendViewColumns(sb *strings.Builder, view *view.View) { } } +func (b *Builder) rewriteGroupBy(SQL string, allColumns []*view.Column, projectedColumns []*view.Column) (string, error) { + if len(projectedColumns) == 0 { + return SQL, nil + } + + trimmed := strings.TrimSpace(SQL) + if trimmed == "" { + return SQL, nil + } + wrapped := strings.HasPrefix(trimmed, "(") && strings.HasSuffix(trimmed, ")") + querySQL := inference.TrimParenthesis(trimmed) + parsed, err := sqlparser.ParseQuery(querySQL) + if err != nil || parsed == nil { + return SQL, err + } + + selectedPositions := projectedColumnPositions(allColumns, projectedColumns) + if len(selectedPositions) > 0 { + items := make(query.List, 0, len(selectedPositions)) + for _, position := range selectedPositions { + if position <= 0 || position > len(parsed.List) { + continue + } + items = append(items, parsed.List[position-1]) + } + if len(items) > 0 { + parsed.List = items + } + } + + positions := projectedGroupByPositions(parsed.List, projectedColumns) + groupBy := make(query.List, 0, len(positions)) + for _, position := range positions { + groupBy = append(groupBy, query.NewItem(expr.NewIntLiteral(strconv.Itoa(position)))) + } + parsed.GroupBy = groupBy + parsed.OrderBy = filterGroupedOrderBy(parsed.OrderBy, parsed.List) + + rewritten := sqlparser.Stringify(parsed) + if wrapped { + rewritten = "(" + rewritten + ")" + } + return rewritten, nil +} + +func projectedColumnPositions(allColumns []*view.Column, projectedColumns []*view.Column) []int { + index := make(map[*view.Column]int, len(allColumns)) + for i, column := range allColumns { + index[column] = i + 1 + } + result := make([]int, 0, len(projectedColumns)) + seen := map[int]bool{} + for _, column := range projectedColumns { + if column == nil { + continue + } + position, ok := index[column] + if !ok || seen[position] { + continue + } + seen[position] = true + result = append(result, position) + } + return result +} + +func filterGroupedOrderBy(orderBy query.List, items query.List) query.List { + if len(orderBy) == 0 || len(items) == 0 { + return orderBy + } + allowed := map[string]bool{} + for _, item := range items { + if item == nil { + continue + } + if item.Expr != nil { + allowed[normalizeExpression(sqlparser.Stringify(item.Expr))] = true + } + if item.Alias != "" { + allowed[normalizeExpression(item.Alias)] = true + } + } + result := make(query.List, 0, len(orderBy)) + for _, item := range orderBy { + if item == nil || item.Expr == nil { + continue + } + if allowed[normalizeExpression(sqlparser.Stringify(item.Expr))] { + result = append(result, item) + } + } + return result +} + +func normalizeExpression(value string) string { + return strings.ToUpper(strings.Join(strings.Fields(strings.TrimSpace(value)), " ")) +} + +func projectedGroupByPositions(items query.List, projectedColumns []*view.Column) []int { + maxLen := len(items) + if len(projectedColumns) < maxLen { + maxLen = len(projectedColumns) + } + result := make([]int, 0, maxLen) + for i := 0; i < maxLen; i++ { + column := projectedColumns[i] + if column != nil && column.Groupable { + result = append(result, i+1) + continue + } + if !isAggregateSelectItem(items[i]) { + result = append(result, i+1) + } + } + return result +} + +func isAggregateSelectItem(item *query.Item) bool { + if item == nil || item.Expr == nil { + return false + } + call, ok := item.Expr.(*expr.Call) + if !ok || call.X == nil { + return false + } + switch actual := call.X.(type) { + case *expr.Ident: + return isAggregateFunction(actual.Name) + case *expr.Selector: + return isAggregateFunction(actual.Name) + } + return false +} + +func isAggregateFunction(name string) bool { + switch strings.ToUpper(strings.TrimSpace(name)) { + case "SUM", "COUNT", "AVG", "MIN", "MAX", "ARRAY_AGG", "STRING_AGG", "ANY_VALUE": + return true + default: + return false + } +} + func (b *Builder) appendViewAlias(sb *strings.Builder, view *view.View) { if view.Alias == "" { return @@ -413,7 +594,7 @@ func (b *Builder) appendRelationColumn(sb *strings.Builder, aView *view.View, se } func (b *Builder) checkViewAndAppendRelColumn(sb *strings.Builder, aView *view.View, relation *view.Relation) error { - if _, ok := aView.ColumnByName(relation.Of.On[0].Column); ok { + if _, _, ok := b.lookupRelationColumn(aView, relation); ok { return nil } @@ -431,13 +612,16 @@ func (b *Builder) checkViewAndAppendRelColumn(sb *strings.Builder, aView *view.V } func (b *Builder) checkSelectorAndAppendRelColumn(sb *strings.Builder, aView *view.View, selector *view.Statelet, relation *view.Relation) error { - if relation == nil || selector.Has(relation.Of.On[0].Column) || aView.Template.IsActualTemplate() { + if relation == nil || aView.Template.IsActualTemplate() { + return nil + } + if b.selectorHasRelationColumn(selector, aView, relation) { return nil } sb.WriteString(separatorFragment) sb.WriteString(" ") - col, ok := aView.ColumnByName(relation.Of.On[0].Column) + col, _, ok := b.lookupRelationColumn(aView, relation) if !ok { sb.WriteString(relation.Of.On[0].Column) } else { @@ -447,6 +631,46 @@ func (b *Builder) checkSelectorAndAppendRelColumn(sb *strings.Builder, aView *vi return nil } +func (b *Builder) selectorHasRelationColumn(selector *view.Statelet, aView *view.View, relation *view.Relation) bool { + if selector == nil || relation == nil || relation.Of == nil || len(relation.Of.On) == 0 { + return false + } + link := relation.Of.On[0] + if selector.Has(link.Column) { + return true + } + if link.Field != "" && selector.Has(link.Field) { + return true + } + if column, _, ok := b.lookupRelationColumn(aView, relation); ok { + if selector.Has(column.Name) { + return true + } + if field := column.Field(); field != nil && selector.Has(field.Name) { + return true + } + } + return false +} + +func (b *Builder) lookupRelationColumn(aView *view.View, relation *view.Relation) (*view.Column, string, bool) { + if aView == nil || relation == nil || relation.Of == nil || len(relation.Of.On) == 0 { + return nil, "", false + } + link := relation.Of.On[0] + if link.Field != "" { + if column, ok := aView.ColumnByName(link.Field); ok { + return column, link.Field, true + } + } + if link.Column != "" { + if column, ok := aView.ColumnByName(link.Column); ok { + return column, link.Column, true + } + } + return nil, "", false +} + func actualLimit(aView *view.View, selector *view.Statelet) int { if selector.Limit != 0 { return selector.Limit diff --git a/service/reader/sql_groupable_test.go b/service/reader/sql_groupable_test.go new file mode 100644 index 000000000..7cb4bbecc --- /dev/null +++ b/service/reader/sql_groupable_test.go @@ -0,0 +1,335 @@ +package reader + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" +) + +func TestBuilder_appendColumns(t *testing.T) { + testView := newGroupableTestView(t) + builder := NewBuilder() + + useCases := []struct { + description string + selector *view.Statelet + expectNames []string + expectNil bool + expectedSQL string + }{ + { + description: "default projection keeps view column order", + selector: view.NewStatelet(), + expectNil: true, + expectedSQL: " t.region_id, t.total_sales, t.country_id", + }, + { + description: "selector projection keeps requested order", + selector: func() *view.Statelet { + selector := view.NewStatelet() + selector.Columns = []string{"country_id", "region_id"} + return selector + }(), + expectNames: []string{"country_id", "region_id"}, + expectedSQL: " country_id, region_id", + }, + { + description: "grouped selector projection uses derived aliases for aggregate columns", + selector: func() *view.Statelet { + selector := view.NewStatelet() + selector.Columns = []string{"account_id", "total_id", "max_id"} + return selector + }(), + expectNames: []string{"account_id", "total_id", "max_id"}, + expectedSQL: " account_id, total_id, max_id", + }, + } + + for _, useCase := range useCases { + t.Run(useCase.description, func(t *testing.T) { + sb := &strings.Builder{} + viewUnderTest := testView + if useCase.description == "grouped selector projection uses derived aliases for aggregate columns" { + viewUnderTest = aggregateSelectorTestView(t) + } + projected, err := builder.appendColumns(sb, viewUnderTest, useCase.selector) + require.NoError(t, err) + require.Equal(t, useCase.expectedSQL, sb.String()) + if useCase.expectNil { + require.Nil(t, projected) + return + } + require.Equal(t, useCase.expectNames, columnNames(projected)) + }) + } +} + +func TestBuilder_rewriteGroupBy(t *testing.T) { + testView := newGroupableTestView(t) + aggregateColumns := aggregateGroupableColumns() + groupedMetrics := groupedMetricsColumns() + builder := NewBuilder() + + useCases := []struct { + description string + sql string + allColumns []*view.Column + projected []*view.Column + expected string + }{ + { + description: "replace existing group by with selected original positions", + sql: "(SELECT region_id, SUM(total_sales) AS total_sales, country_id FROM sales GROUP BY 1, 3)", + allColumns: testView.Columns, + projected: []*view.Column{testView.Columns[2], testView.Columns[1]}, + expected: "(SELECT country_id, SUM(total_sales) AS total_sales FROM sales GROUP BY 1)", + }, + { + description: "remove group by when no selected projected column is groupable", + sql: "(SELECT region_id, SUM(total_sales) AS total_sales, country_id FROM sales GROUP BY 1, 3)", + allColumns: testView.Columns, + projected: []*view.Column{testView.Columns[1]}, + expected: "(SELECT SUM(total_sales) AS total_sales FROM sales)", + }, + { + description: "add group by when absent", + sql: "(SELECT region_id, SUM(total_sales) AS total_sales, country_id FROM sales)", + allColumns: testView.Columns, + projected: []*view.Column{testView.Columns[0], testView.Columns[1]}, + expected: "(SELECT region_id, SUM(total_sales) AS total_sales FROM sales GROUP BY 1)", + }, + { + description: "skip rewrite when no specific projection was selected", + sql: "(SELECT region_id, SUM(total_sales) AS total_sales, country_id FROM sales GROUP BY 1, 3)", + allColumns: testView.Columns, + projected: nil, + expected: "(SELECT region_id, SUM(total_sales) AS total_sales, country_id FROM sales GROUP BY 1, 3)", + }, + { + description: "rewrite grouped aggregates to selected groupable positions only", + sql: "(SELECT account_id, user_created, SUM(id) AS total_id, MAX(id) AS max_id FROM vendor GROUP BY 1, 2)", + allColumns: aggregateColumns, + projected: []*view.Column{aggregateColumns[0], aggregateColumns[2], aggregateColumns[3]}, + expected: "(SELECT account_id, SUM(id) AS total_id, MAX(id) AS max_id FROM vendor GROUP BY 1)", + }, + { + description: "rewrite grouped metrics query prunes unselected dimensions from select list", + sql: "(SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, SUM(p.bids) AS bids, SUM(p.impressions) AS impressions, SUM(p.clicks) AS clicks, SUM(p.conversions) AS conversions, SUM(p.total_spend) AS total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL ? DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?))) GROUP BY 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 LIMIT 1000)", + allColumns: groupedMetrics, + projected: []*view.Column{ + groupedMetrics[0], + groupedMetrics[1], + groupedMetrics[2], + groupedMetrics[3], + groupedMetrics[4], + groupedMetrics[5], + groupedMetrics[6], + }, + expected: "(SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id FROM `viant-mediator.forecaster.fact_perf_daily_mv` p WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL ? DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?))) GROUP BY 1, 2, 3, 4, 5, 6, 7)", + }, + { + description: "rewrite grouped metrics CTE prunes unselected dimensions and preserves order", + sql: "WITH params AS (SELECT CAST(GREATEST(?, 1) AS INT64) AS date_interval), last_n AS (SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, p.bids, p.impressions, p.clicks, p.conversions, p.total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p JOIN params prm ON TRUE WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL prm.date_interval DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?)))) SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, SUM(p.bids) AS bids, SUM(p.impressions) AS impressions, SUM(p.clicks) AS clicks, SUM(p.conversions) AS conversions, SUM(p.total_spend) AS total_spend FROM last_n p GROUP BY 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 ORDER BY p.event_date", + allColumns: groupedMetrics, + projected: []*view.Column{ + groupedMetrics[0], + groupedMetrics[1], + groupedMetrics[2], + groupedMetrics[3], + groupedMetrics[4], + groupedMetrics[5], + groupedMetrics[6], + }, + expected: "WITH params AS (SELECT CAST(GREATEST(?, 1) AS INT64) AS date_interval), last_n AS (SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, p.bids, p.impressions, p.clicks, p.conversions, p.total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p JOIN params prm ON TRUE WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL prm.date_interval DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?)))) SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id FROM last_n p GROUP BY 1, 2, 3, 4, 5, 6, 7 ORDER BY p.event_date", + }, + { + description: "rewrite grouped metrics CTE keeps selected non aggregate site_type in group by", + sql: "WITH params AS (SELECT CAST(GREATEST(?, 1) AS INT64) AS date_interval), last_n AS (SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, p.bids, p.impressions, p.clicks, p.conversions, p.total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p JOIN params prm ON TRUE WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL prm.date_interval DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?)) AND ((p.campaign_id IN (?))))) SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.site_type, SUM(p.bids) AS bids, SUM(p.impressions) AS impressions, SUM(p.clicks) AS clicks, SUM(p.conversions) AS conversions, SUM(p.total_spend) AS total_spend FROM last_n p GROUP BY 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 ORDER BY p.event_date LIMIT 1000", + allColumns: func() []*view.Column { + cloned := cloneColumns(groupedMetrics) + cloned[10].Groupable = false + return cloned + }(), + projected: func() []*view.Column { + cloned := cloneColumns(groupedMetrics) + cloned[10].Groupable = false + return []*view.Column{ + cloned[0], + cloned[1], + cloned[2], + cloned[3], + cloned[4], + cloned[5], + cloned[10], + cloned[11], + cloned[12], + cloned[13], + cloned[14], + cloned[15], + } + }(), + expected: "WITH params AS (SELECT CAST(GREATEST(?, 1) AS INT64) AS date_interval), last_n AS (SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, p.bids, p.impressions, p.clicks, p.conversions, p.total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p JOIN params prm ON TRUE WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL prm.date_interval DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?)) AND ((p.campaign_id IN (?))))) SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.site_type, SUM(p.bids) AS bids, SUM(p.impressions) AS impressions, SUM(p.clicks) AS clicks, SUM(p.conversions) AS conversions, SUM(p.total_spend) AS total_spend FROM last_n p GROUP BY 1, 2, 3, 4, 5, 6, 7 ORDER BY p.event_date", + }, + { + description: "rewrite grouped metrics with publisher subset renumbers group by after pruning", + sql: "(SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, SUM(p.bids) AS bids, SUM(p.impressions) AS impressions, SUM(p.clicks) AS clicks, SUM(p.conversions) AS conversions, SUM(p.total_spend) AS total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL ? DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?))) GROUP BY 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 LIMIT 1000)", + allColumns: groupedMetrics, + projected: []*view.Column{ + groupedMetrics[0], + groupedMetrics[1], + groupedMetrics[2], + groupedMetrics[3], + groupedMetrics[4], + groupedMetrics[5], + groupedMetrics[7], + }, + expected: "(SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.publisher_id FROM `viant-mediator.forecaster.fact_perf_daily_mv` p WHERE p.event_date BETWEEN DATE_SUB(CURRENT_DATE(), INTERVAL ? DAY) AND DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY) AND (((p.agency_id = ?))) GROUP BY 1, 2, 3, 4, 5, 6, 7)", + }, + { + description: "rewrite grouped report projection drops order by on pruned dimension", + sql: "WITH last_n AS (SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, p.bids, p.impressions, p.clicks, p.conversions, p.total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p WHERE p.event_date BETWEEN DATE(DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY)) AND DATE(CURRENT_DATE()-1)) SELECT p.ad_order_id, SUM(p.bids) AS bids FROM last_n p GROUP BY 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 ORDER BY p.event_date LIMIT 1000", + allColumns: groupedMetrics, + projected: []*view.Column{ + groupedMetrics[4], + groupedMetrics[11], + }, + expected: "WITH last_n AS (SELECT p.event_date, p.agency_id, p.advertiser_id, p.campaign_id, p.ad_order_id, p.audience_id, p.deal_id, p.publisher_id, p.channel_id, p.country, p.site_type, p.bids, p.impressions, p.clicks, p.conversions, p.total_spend FROM `viant-mediator.forecaster.fact_perf_daily_mv` p WHERE p.event_date BETWEEN DATE(DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY)) AND DATE(CURRENT_DATE()-1)) SELECT p.ad_order_id, SUM(p.bids) AS bids FROM last_n p GROUP BY 1", + }, + } + + for _, useCase := range useCases { + t.Run(useCase.description, func(t *testing.T) { + actual, err := builder.rewriteGroupBy(useCase.sql, useCase.allColumns, useCase.projected) + require.NoError(t, err) + require.Equal(t, normalizeSQL(useCase.expected), normalizeSQL(actual)) + }) + } +} + +func TestBuilder_appendRelationColumn_UsesProjectedRelationAliasForGroupedDerivedView(t *testing.T) { + builder := NewBuilder() + aView := view.NewView("disqualified", "disqualified", + view.WithConnector(view.NewConnector("test", "sqlite3", ":memory:")), + view.WithColumns(view.Columns{ + &view.Column{Name: "TAXONOMY_ID", DataType: "int"}, + &view.Column{Name: "IS_DISQUALIFIED", DataType: "int"}, + }), + ) + require.NoError(t, aView.Init(context.Background(), view.EmptyResource())) + + relation := &view.Relation{ + Of: &view.ReferenceView{ + On: view.Links{ + &view.Link{Field: "TaxonomyId", Column: "dq.SEGMENT_ID"}, + }, + }, + } + + t.Run("default projection does not append raw source column when projected alias exists", func(t *testing.T) { + sb := &strings.Builder{} + require.NoError(t, builder.checkViewAndAppendRelColumn(sb, aView, relation)) + require.Equal(t, "", sb.String()) + }) + + t.Run("selector projection appends projected alias expression instead of raw source column", func(t *testing.T) { + sb := &strings.Builder{} + selector := view.NewStatelet() + selector.Columns = []string{"IS_DISQUALIFIED"} + selector.Init(aView) + require.NoError(t, builder.checkSelectorAndAppendRelColumn(sb, aView, selector, relation)) + require.Equal(t, ", TAXONOMY_ID", sb.String()) + }) +} + +func newGroupableTestView(t *testing.T) *view.View { + t.Helper() + trueValue := true + aView := view.NewView("sales", "sales", + view.WithGroupable(true), + view.WithConnector(view.NewConnector("test", "sqlite3", ":memory:")), + view.WithColumns(view.Columns{ + &view.Column{Name: "region_id", DataType: "string"}, + &view.Column{Name: "total_sales", DataType: "float64"}, + &view.Column{Name: "country_id", DataType: "string"}, + }), + ) + aView.ColumnsConfig = map[string]*view.ColumnConfig{ + "region_id": {Name: "region_id", Groupable: &trueValue}, + "country_id": {Name: "country_id", Groupable: &trueValue}, + } + require.NoError(t, aView.Init(context.Background(), view.EmptyResource())) + return aView +} + +func aggregateGroupableColumns() []*view.Column { + return []*view.Column{ + {Name: "account_id", Groupable: true}, + {Name: "user_created", Groupable: true}, + {Name: "total_id"}, + {Name: "max_id"}, + } +} + +func aggregateSelectorTestView(t *testing.T) *view.View { + t.Helper() + aView := view.NewView("vendor", "vendor", + view.WithGroupable(true), + view.WithConnector(view.NewConnector("test", "sqlite3", ":memory:")), + view.WithColumns(view.Columns{ + &view.Column{Name: "account_id", DataType: "int", Groupable: true}, + &view.Column{Name: "user_created", DataType: "int", Groupable: true}, + &view.Column{Name: "total_id", DataType: "float64", Expression: "SUM(id)", Aggregate: true}, + &view.Column{Name: "max_id", DataType: "int", Expression: "MAX(id)", Aggregate: true}, + }), + ) + require.NoError(t, aView.Init(context.Background(), view.EmptyResource())) + return aView +} + +func groupedMetricsColumns() []*view.Column { + return []*view.Column{ + {Name: "event_date", Groupable: true}, + {Name: "agency_id", Groupable: true}, + {Name: "advertiser_id", Groupable: true}, + {Name: "campaign_id", Groupable: true}, + {Name: "ad_order_id", Groupable: true}, + {Name: "audience_id", Groupable: true}, + {Name: "deal_id", Groupable: true}, + {Name: "publisher_id", Groupable: true}, + {Name: "channel_id", Groupable: true}, + {Name: "country", Groupable: true}, + {Name: "site_type", Groupable: true}, + {Name: "bids"}, + {Name: "impressions"}, + {Name: "clicks"}, + {Name: "conversions"}, + {Name: "total_spend"}, + } +} + +func cloneColumns(columns []*view.Column) []*view.Column { + result := make([]*view.Column, len(columns)) + for i, column := range columns { + if column == nil { + continue + } + cloned := *column + result[i] = &cloned + } + return result +} + +func columnNames(columns []*view.Column) []string { + result := make([]string, len(columns)) + for i, column := range columns { + result[i] = column.Name + } + return result +} + +func normalizeSQL(SQL string) string { + return strings.Join(strings.Fields(SQL), " ") +} diff --git a/service/reader/sql_projection_regression_test.go b/service/reader/sql_projection_regression_test.go new file mode 100644 index 000000000..9592f741f --- /dev/null +++ b/service/reader/sql_projection_regression_test.go @@ -0,0 +1,55 @@ +package reader + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/sqlparser" +) + +func TestBuilder_appendColumns_UsesAliasesForDiscoveredExpressions(t *testing.T) { + builder := NewBuilder() + useCases := []struct { + description string + sql string + expectedSQL string + }{ + { + description: "case expression keeps outer alias projection", + sql: "SELECT (CASE WHEN 'user_name' = 'user_name' THEN u.STR_ID ELSE NULL END) AS VALUE FROM CI_EVENT ev LEFT JOIN CI_CONTACTS u ON ev.CREATED_USER = u.ID", + expectedSQL: " t.VALUE", + }, + { + description: "coalesce expression keeps discovered alias projection", + sql: "SELECT COALESCE(sl.APPROVED_SITE_CNT,0) AS NUMBER_OF_SITES FROM CI_SITE_LIST sl", + expectedSQL: " t.NUMBER_OF_SITES", + }, + } + + for _, useCase := range useCases { + t.Run(useCase.description, func(t *testing.T) { + parsed, err := sqlparser.ParseQuery(useCase.sql) + require.NoError(t, err) + columns := view.NewColumns(sqlparser.NewColumns(parsed.List), nil) + for _, column := range columns { + if strings.TrimSpace(column.DataType) == "" { + column.DataType = "string" + } + } + aView := view.NewView("projection", "projection", + view.WithConnector(view.NewConnector("test", "sqlite3", ":memory:")), + view.WithColumns(columns), + ) + require.NoError(t, aView.Init(context.Background(), view.EmptyResource())) + + sb := &strings.Builder{} + projected, err := builder.appendColumns(sb, aView, view.NewStatelet()) + require.NoError(t, err) + require.Nil(t, projected) + require.Equal(t, useCase.expectedSQL, sb.String()) + }) + } +} diff --git a/service/session/reader.go b/service/session/reader.go index 2184634e1..94a91d528 100644 --- a/service/session/reader.go +++ b/service/session/reader.go @@ -2,11 +2,22 @@ package session import ( "context" + "fmt" reader "github.com/viant/datly/service/reader" "github.com/viant/datly/view" + "os" + "runtime/debug" ) func (s *Session) ReadInto(ctx context.Context, dest interface{}, aView *view.View) error { + if os.Getenv("DATLY_DEBUG_READINTO") == "1" { + defer func() { + if r := recover(); r != nil { + fmt.Printf("[READINTO DEBUG] panic view=%s dest=%T err=%v\n%s\n", aView.Name, dest, r, debug.Stack()) + panic(r) + } + }() + } if err := s.SetViewState(ctx, aView); err != nil { return err } diff --git a/service/session/selector.go b/service/session/selector.go index ca8b64c76..bfcb0c2dd 100644 --- a/service/session/selector.go +++ b/service/session/selector.go @@ -8,6 +8,7 @@ import ( "github.com/viant/datly/service/session/criteria" "github.com/viant/datly/view" + "github.com/viant/datly/view/state" "github.com/viant/tagly/format/text" "github.com/viant/xdatly/codec" "github.com/viant/xdatly/handler/response" @@ -77,22 +78,22 @@ func (s *Session) setQuerySelector(ctx context.Context, ns *view.NamespaceView, injected = resolveInjectedQuerySelector(ns, opts.locatorOpt.QuerySelectors) } if err = s.populateFieldQuerySelector(ctx, ns, opts); err != nil { - return response.NewParameterError(ns.View.Name, selectorParameters.FieldsParameter.Name, err) + return response.NewParameterError(ns.View.Name, selectorParameterName(selectorParameters.FieldsParameter, view.QueryStateParameters.FieldsParameter), err) } if err = s.populateLimitQuerySelector(ctx, ns, opts); err != nil { - return response.NewParameterError(ns.View.Name, selectorParameters.LimitParameter.Name, err) + return response.NewParameterError(ns.View.Name, selectorParameterName(selectorParameters.LimitParameter, view.QueryStateParameters.LimitParameter), err) } if err = s.populateOffsetQuerySelector(ctx, ns, opts); err != nil { - return response.NewParameterError(ns.View.Name, selectorParameters.OffsetParameter.Name, err) + return response.NewParameterError(ns.View.Name, selectorParameterName(selectorParameters.OffsetParameter, view.QueryStateParameters.OffsetParameter), err) } if err = s.populateOrderByQuerySelector(ctx, ns, opts); err != nil { - return response.NewParameterError(ns.View.Name, selectorParameters.OrderByParameter.Name, err) + return response.NewParameterError(ns.View.Name, selectorParameterName(selectorParameters.OrderByParameter, view.QueryStateParameters.OrderByParameter), err) } if err = s.populateCriteriaQuerySelector(ctx, ns, opts); err != nil { - return response.NewParameterError(ns.View.Name, selectorParameters.CriteriaParameter.Name, err) + return response.NewParameterError(ns.View.Name, selectorParameterName(selectorParameters.CriteriaParameter, view.QueryStateParameters.CriteriaParameter), err) } if err = s.populatePageQuerySelector(ctx, ns, opts); err != nil { - return response.NewParameterError(ns.View.Name, selectorParameters.PageParameter.Name, err) + return response.NewParameterError(ns.View.Name, selectorParameterName(selectorParameters.PageParameter, view.QueryStateParameters.PageParameter), err) } // Apply injected selector last so it takes precedence over request-derived values, @@ -113,6 +114,16 @@ func (s *Session) setQuerySelector(ctx context.Context, ns *view.NamespaceView, return nil } +func selectorParameterName(parameter, fallback *state.Parameter) string { + if parameter != nil && parameter.Name != "" { + return parameter.Name + } + if fallback != nil && fallback.Name != "" { + return fallback.Name + } + return "" +} + func (s *Session) applyInjectedQuerySelector(ns *view.NamespaceView, selector *view.Statelet, injected *hstate.NamedQuerySelector) error { if injected == nil || selector == nil { return nil diff --git a/service/session/selector_test.go b/service/session/selector_test.go new file mode 100644 index 000000000..4907e304e --- /dev/null +++ b/service/session/selector_test.go @@ -0,0 +1,54 @@ +package session + +import ( + "context" + "net/http" + "reflect" + "testing" + + "github.com/viant/datly/repository" + "github.com/viant/datly/view" + vstate "github.com/viant/datly/view/state" + "github.com/viant/datly/view/state/kind/locator" +) + +func TestSessionBind_QuerySelectorErrorDoesNotPanicWithoutCustomParameters(t *testing.T) { + ctx := context.Background() + resource := view.NewResource(nil) + aView := &view.View{ + Name: "v", + Mode: view.ModeQuery, + Selector: &view.Config{ + Constraints: &view.Constraints{}, + }, + } + aView.SetResource(resource) + aView.Template = &view.Template{Schema: vstate.NewSchema(reflect.TypeOf(struct{ Dummy int }{}))} + if err := aView.Template.Init(ctx, resource, aView); err != nil { + t.Fatalf("failed to init template: %v", err) + } + if err := aView.Selector.Init(ctx, resource, aView); err != nil { + t.Fatalf("failed to init selector: %v", err) + } + + component := &repository.Component{View: aView} + outputType, err := vstate.NewType( + vstate.WithSchema(vstate.NewSchema(reflect.TypeOf(struct{ X int }{}))), + vstate.WithResource(aView.Resource()), + ) + if err != nil { + t.Fatalf("failed to build component output type: %v", err) + } + component.Output.Type = *outputType + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/?_orderby=id", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + + sess := New(aView, WithComponent(component), WithLocatorOptions(locator.WithRequest(req))) + err = sess.SetViewState(ctx, aView) + if err == nil { + t.Fatal("expected query selector error") + } +} diff --git a/service/session/state.go b/service/session/state.go index 5779250fc..9e182c61d 100644 --- a/service/session/state.go +++ b/service/session/state.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "os" "reflect" "strings" "sync" @@ -219,6 +220,16 @@ func (s *Session) ViewOptions(aView *view.View, opts ...Option) *Options { var parameters state.NamedParameters if aView.Template != nil { parameters = aView.Template.Parameters.Index() + if aView.Template.UseResourceParameterLookup && aView.GetResource() != nil { + merged := state.NamedParameters{} + for k, v := range aView.GetResource().NamedParameters() { + merged[k] = v + } + for k, v := range parameters { + merged[k] = v + } + parameters = merged + } } viewOptions.kindLocator = s.kindLocator.With(s.viewLookupOptions(aView, parameters, viewOptions)...) @@ -539,6 +550,9 @@ func (s *Session) ensureValidValue(value interface{}, parameter *state.Parameter } return converted.Interface(), nil } + if wrapped, ok := wrapComponentResult(parameter, value, valueType, rawSrcType, rawDestType, destIsPtr); ok { + return wrapped, nil + } if options.shallReportNotAssignable() { fmt.Printf("parameter %v is not directly assignable from %s:(%s)\nsrc:%s \ndst:%s\n", parameter.Name, parameter.In.Kind, parameter.In.Name, valueType.String(), destType.String()) @@ -563,6 +577,43 @@ func (s *Session) ensureValidValue(value interface{}, parameter *state.Parameter return value, nil } +func wrapComponentResult(parameter *state.Parameter, value interface{}, valueType, rawSrcType, rawDestType reflect.Type, destIsPtr bool) (interface{}, bool) { + if parameter == nil || parameter.In == nil || parameter.In.Kind != state.KindComponent { + return nil, false + } + if rawSrcType.Kind() != reflect.Struct || rawDestType.Kind() != reflect.Struct { + return nil, false + } + field, ok := rawDestType.FieldByName("Data") + if !ok { + return nil, false + } + fieldType := field.Type + srcValue := reflect.ValueOf(value) + if valueType.Kind() == reflect.Ptr { + if srcValue.IsNil() { + return nil, true + } + } + if !valueType.AssignableTo(fieldType) { + if valueType.Kind() == reflect.Ptr && valueType.Elem().AssignableTo(fieldType) { + srcValue = srcValue.Elem() + } else if valueType.Kind() != reflect.Ptr && reflect.PointerTo(valueType).AssignableTo(fieldType) { + ptr := reflect.New(valueType) + ptr.Elem().Set(srcValue) + srcValue = ptr + } else { + return nil, false + } + } + target := reflect.New(rawDestType) + target.Elem().FieldByIndex(field.Index).Set(srcValue) + if destIsPtr { + return target.Interface(), true + } + return target.Elem().Interface(), true +} + func ensureAssignable(fieldName string, destFieldType reflect.Type, srcFieldType reflect.Type) bool { switch destFieldType.Kind() { case reflect.Slice: @@ -722,6 +773,15 @@ func (s *Session) lookupValue(ctx context.Context, parameter *state.Parameter, o func (s *Session) adjustAndCache(ctx context.Context, parameter *state.Parameter, opts *Options, has bool, value interface{}, cachable bool) (interface{}, bool, error) { var err error + if os.Getenv("DATLY_DEBUG_ADJUST") == "1" { + fmt.Printf("[ADJUST DEBUG][start] param=%s kind=%s has=%v value=%T outputType=%v schemaType=%v\n", + parameter.Name, parameter.In.Kind, has, value, parameter.OutputType(), func() reflect.Type { + if parameter.Schema == nil { + return nil + } + return parameter.Schema.Type() + }()) + } if !has && parameter.Value != nil { has = true value = parameter.Value @@ -732,6 +792,9 @@ func (s *Session) adjustAndCache(ctx context.Context, parameter *state.Parameter if value, err = s.adjustValue(parameter, value); err != nil { return nil, false, err } + if os.Getenv("DATLY_DEBUG_ADJUST") == "1" { + fmt.Printf("[ADJUST DEBUG][post-adjust] param=%s value=%T\n", parameter.Name, value) + } if parameter.Output != nil { // Defensive: ensure codec is initialized before Transform. if !parameter.Output.Initialized() { @@ -745,6 +808,9 @@ func (s *Session) adjustAndCache(ctx context.Context, parameter *state.Parameter return nil, false, fmt.Errorf("failed to transform %s with %s: %v, %w", parameter.Name, parameter.Output.Name, value, err) } value = transformed + if os.Getenv("DATLY_DEBUG_ADJUST") == "1" { + fmt.Printf("[ADJUST DEBUG][post-transform] param=%s value=%T\n", parameter.Name, value) + } } if has && err == nil && cachable { s.setValue(parameter, value) diff --git a/service/session/state_test.go b/service/session/state_test.go index 497d6aca6..8c4fac8e8 100644 --- a/service/session/state_test.go +++ b/service/session/state_test.go @@ -288,4 +288,36 @@ func TestSessionEnsureValidValue_Transitions(t *testing.T) { t.Fatalf("expected B=%d, got %d", *original.B, gotB.Elem().Int()) } }) + + t.Run("component_result_wraps_into_data_holder", func(t *testing.T) { + type componentRow struct { + IsReadOnly int + } + type componentHolder struct { + Data *componentRow + } + + value := &componentRow{IsReadOnly: 1} + parameter := &state.Parameter{ + Name: "Auth", + In: state.NewComponent("GET:/auth"), + Schema: state.NewSchema(reflect.TypeOf(componentHolder{})), + } + selector := newSelector(t, reflect.TypeOf(componentHolder{})) + sess := &Session{} + opts := NewOptions(WithReportNotAssignable(false)) + + got, err := sess.ensureValidValue(value, parameter, selector, opts) + if err != nil { + t.Fatalf("ensureValidValue error: %v", err) + } + + holder, ok := got.(componentHolder) + if !ok { + t.Fatalf("expected componentHolder, got %T", got) + } + if holder.Data == nil || holder.Data.IsReadOnly != 1 { + t.Fatalf("expected wrapped component result, got %#v", got) + } + }) } diff --git a/service/session/stater.go b/service/session/stater.go index 5b0b0d6ab..392a6d3a8 100644 --- a/service/session/stater.go +++ b/service/session/stater.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "os" "reflect" "runtime/debug" @@ -46,6 +47,9 @@ func (s *Session) Bind(ctx context.Context, dest interface{}, opts ...hstate.Opt defer func() { if r := recover(); r != nil { panicMsg := fmt.Sprintf("Panic occurred: %v, Stack trace: %v", r, string(debug.Stack())) + if os.Getenv("DATLY_DEBUG_BIND") == "1" { + fmt.Printf("[BIND DEBUG] %s\n", panicMsg) + } logger := s.Logger() if logger == nil { panic(panicMsg) diff --git a/testutil/shapeparity/scan.go b/testutil/shapeparity/scan.go new file mode 100644 index 000000000..97acae814 --- /dev/null +++ b/testutil/shapeparity/scan.go @@ -0,0 +1,100 @@ +package shapeparity + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/viant/afs" + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/cmd/command" + "github.com/viant/datly/cmd/options" + dqlscan "github.com/viant/datly/repository/shape/dql/scan" +) + +// ScanDQL runs the legacy translator into a temporary repository, then feeds the +// generated route YAML back through the shape scanner so parity tests can +// compare the legacy YAML contract against shape IR. +func ScanDQL(ctx context.Context, req *dqlscan.Request) (*dqlscan.Result, error) { + if req == nil { + return nil, fmt.Errorf("shape parity scan request was nil") + } + dqlURL := strings.TrimSpace(req.DQLURL) + if dqlURL == "" { + return nil, fmt.Errorf("shape parity scan request DQLURL was empty") + } + fs := afs.New() + dqlBytes, err := fs.DownloadWithURL(ctx, dqlURL) + if err != nil { + return nil, fmt.Errorf("failed to read DQL %s: %w", dqlURL, err) + } + tmpRepo, err := os.MkdirTemp("", "datly-shapeparity-*") + if err != nil { + return nil, fmt.Errorf("failed to create temp repository: %w", err) + } + defer os.RemoveAll(tmpRepo) + + projectRoot := inferProjectRoot(req, dqlURL) + modulePrefix := strings.Trim(strings.TrimSpace(req.ModulePrefix), "/") + apiPrefix := strings.TrimSpace(req.APIPrefix) + if apiPrefix == "" { + apiPrefix = "/v1/api" + } + repoOpts := options.Repository{ + RepositoryURL: tmpRepo, + ProjectURL: projectRoot, + APIPrefix: apiPrefix, + } + repoOpts.Connectors = append(repoOpts.Connectors, req.Connectors...) + if cfgURL := strings.TrimSpace(req.ConfigURL); cfgURL != "" { + repoOpts.Configs.Append(cfgURL) + } + opts := &options.Options{ + Translate: &options.Translate{ + Rule: options.Rule{ + Project: projectRoot, + ModulePrefix: modulePrefix, + Source: []string{dqlURL}, + ModuleLocation: func() string { + if req.Repository != "" { + return filepath.Join(req.Repository, "pkg") + } + return filepath.Join(projectRoot, "pkg") + }(), + Engine: options.EngineLegacy, + }, + Repository: repoOpts, + }, + } + if err = opts.Init(ctx); err != nil { + return nil, fmt.Errorf("failed to initialise legacy translate options: %w", err) + } + if err = command.New().Translate(ctx, opts); err != nil { + return nil, fmt.Errorf("failed to translate DQL %s with legacy pipeline: %w", dqlURL, err) + } + + ruleName := strings.TrimSuffix(filepath.Base(url.Path(dqlURL)), filepath.Ext(url.Path(dqlURL))) + routeYAMLURL := filepath.Join(tmpRepo, "Datly", "routes") + if modulePrefix != "" { + routeYAMLURL = filepath.Join(routeYAMLURL, filepath.FromSlash(modulePrefix)) + } + routeYAMLURL = filepath.Join(routeYAMLURL, ruleName+".yaml") + yamlBytes, err := fs.DownloadWithURL(ctx, routeYAMLURL) + if err != nil { + return nil, fmt.Errorf("failed to read generated route YAML %s: %w", routeYAMLURL, err) + } + return dqlscan.New().Result(ruleName, yamlBytes, string(dqlBytes), req) +} + +func inferProjectRoot(req *dqlscan.Request, dqlURL string) string { + if repositoryRoot := strings.TrimSpace(req.Repository); repositoryRoot != "" { + return filepath.Dir(filepath.Clean(repositoryRoot)) + } + if scheme := url.Scheme(dqlURL, file.Scheme); scheme != "" && scheme != file.Scheme { + return filepath.Dir(url.Path(dqlURL)) + } + return filepath.Dir(filepath.Clean(url.Path(dqlURL))) +} diff --git a/view/column.go b/view/column.go index 90bd0fad6..9a20680a9 100644 --- a/view/column.go +++ b/view/column.go @@ -10,6 +10,7 @@ import ( "github.com/viant/tagly/format/text" "github.com/viant/xreflect" "reflect" + "strconv" "strings" ) @@ -21,7 +22,9 @@ type ( Tag string `json:",omitempty"` Expression string `json:",omitempty"` + Aggregate bool `json:",omitempty"` Filterable bool `json:",omitempty"` + Groupable bool `json:",omitempty"` Nullable bool `json:",omitempty"` Default string `json:",omitempty"` FormatTag *format.Tag `json:",omitempty"` @@ -34,6 +37,7 @@ type ( field *reflect.StructField _initialized bool _fieldName string + _groupableSet bool } ColumnOption func(c *Column) ) @@ -79,6 +83,9 @@ func (c *Column) Init(resource state.Resource, caseFormat text.CaseFormat, allow if c.Name == "" { return fmt.Errorf("column name was empty") } + if err := c.initGroupable(); err != nil { + return err + } err := c.EnsureType(resource.LookupType()) if err != nil { return err @@ -100,6 +107,23 @@ func (c *Column) Init(resource state.Resource, caseFormat text.CaseFormat, allow return nil } +func (c *Column) initGroupable() error { + if c._groupableSet || c.Tag == "" { + return nil + } + value, ok := reflect.StructTag(c.Tag).Lookup("groupable") + if !ok { + return nil + } + groupable, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid groupable tag for column %s: %w", c.Name, err) + } + c.Groupable = groupable + c._groupableSet = true + return nil +} + func (c *Column) EnsureType(lookupType xreflect.LookupType) error { if c.rType != nil && c.rType != xreflect.InterfaceType { return nil @@ -164,6 +188,10 @@ func (c *Column) defaultValue(rType reflect.Type) string { } } +func (c *Column) DefaultValue() string { + return c.defaultValue(c.rType) +} + func (c *Column) FieldName() string { return c._fieldName } @@ -199,6 +227,10 @@ func (c *Column) ApplyConfig(config *ColumnConfig) { if config.Default != nil { c.Default = *config.Default } + if config.Groupable != nil { + c.Groupable = *config.Groupable + c._groupableSet = true + } c._initialized = false } @@ -237,6 +269,7 @@ type ( Codec *state.Codec `json:",omitempty"` DataType *string `json:",omitempty"` Required *bool `json:",omitempty"` + Groupable *bool `json:",omitempty"` Format *string `json:",omitempty"` Tag *string `json:",omitempty"` Default *string `json:",omitempty"` diff --git a/view/column_lookup_test.go b/view/column_lookup_test.go new file mode 100644 index 000000000..db598661d --- /dev/null +++ b/view/column_lookup_test.go @@ -0,0 +1,27 @@ +package view + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestView_ColumnByName_UsesIndexedLookup(t *testing.T) { + aView := NewView("disqualified", "disqualified", + WithConnector(NewConnector("test", "sqlite3", ":memory:")), + WithColumns(Columns{ + &Column{Name: "TAXONOMY_ID", DataType: "int", Tag: `source:"SEGMENT_ID"`}, + &Column{Name: "IS_DISQUALIFIED", DataType: "int"}, + }), + ) + require.NoError(t, aView.Init(context.Background(), EmptyResource())) + + column, ok := aView.ColumnByName("SEGMENT_ID") + require.True(t, ok) + require.Equal(t, "TAXONOMY_ID", column.Name) + + column, ok = aView.ColumnByName("taxonomy_id") + require.True(t, ok) + require.Equal(t, "TAXONOMY_ID", column.Name) +} diff --git a/view/columns.go b/view/columns.go index 44b5c8f19..bb9d85dcf 100644 --- a/view/columns.go +++ b/view/columns.go @@ -197,6 +197,7 @@ func NewColumns(columns sqlparser.Columns, config map[string]*ColumnConfig) Colu } name = item.Identity() column := NewColumn(name, item.Type, item.RawType, item.IsNullable, WithColumnTag(item.Tag)) + column.Aggregate = isAggregateProjection(item.Expression) if item.Name != item.Alias && item.Alias != "" && item.Name != "" { column.Tag += fmt.Sprintf(`source:"%v"`, item.Name) } @@ -210,3 +211,17 @@ func NewColumns(columns sqlparser.Columns, config map[string]*ColumnConfig) Colu } return result } + +func isAggregateProjection(expression string) bool { + expression = strings.ToLower(strings.TrimSpace(expression)) + switch { + case strings.Contains(expression, "count("), + strings.Contains(expression, "sum("), + strings.Contains(expression, "avg("), + strings.Contains(expression, "min("), + strings.Contains(expression, "max("): + return true + default: + return false + } +} diff --git a/view/config.go b/view/config.go index 5dd70f6de..bf1dfe2c8 100644 --- a/view/config.go +++ b/view/config.go @@ -32,7 +32,7 @@ var QueryStateParameters = &Config{ PageParameter: &state.Parameter{Name: "Page", In: state.NewQueryLocation(PageQuery), Schema: state.NewSchema(xreflect.IntType)}, FieldsParameter: &state.Parameter{Name: "Fields", In: state.NewQueryLocation(FieldsQuery), Schema: state.NewSchema(stringsType)}, OrderByParameter: &state.Parameter{Name: "OrderBy", In: state.NewQueryLocation(OrderByQuery), Schema: state.NewSchema(stringsType)}, - CriteriaParameter: &state.Parameter{Name: "Criteria", In: state.NewQueryLocation(OrderByQuery), Schema: state.NewSchema(xreflect.StringType)}, + CriteriaParameter: &state.Parameter{Name: "Criteria", In: state.NewQueryLocation(CriteriaQuery), Schema: state.NewSchema(xreflect.StringType)}, SyncFlagParameter: &state.Parameter{Name: "SyncFlag", Cacheable: &trueValue, In: state.NewState(SyncFlag), Schema: state.NewSchema(boolType)}, ContentFormatParameter: &state.Parameter{Name: "ContentFormat", In: state.NewQueryLocation(ContentFormat), Schema: state.NewSchema(xreflect.StringType)}, } diff --git a/view/config_test.go b/view/config_test.go new file mode 100644 index 000000000..53f081717 --- /dev/null +++ b/view/config_test.go @@ -0,0 +1,9 @@ +package view + +import "testing" + +func TestQueryStateParameters_CriteriaParameterUsesCriteriaQuery(t *testing.T) { + if QueryStateParameters.CriteriaParameter == nil || QueryStateParameters.CriteriaParameter.In.Name != CriteriaQuery { + t.Fatalf("expected CriteriaParameter query name %q, got %#v", CriteriaQuery, QueryStateParameters.CriteriaParameter) + } +} diff --git a/view/groupable_test.go b/view/groupable_test.go new file mode 100644 index 000000000..0b2716713 --- /dev/null +++ b/view/groupable_test.go @@ -0,0 +1,46 @@ +package view + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/tagly/format/text" +) + +func TestColumn_Init_GroupableTag(t *testing.T) { + column := &Column{ + Name: "region", + DataType: "string", + Tag: `groupable:"true"`, + } + + err := column.Init(NewResources(EmptyResource(), &View{}), text.CaseFormatLowerUnderscore, true) + require.NoError(t, err) + require.True(t, column.Groupable) +} + +func TestView_IsGroupable(t *testing.T) { + groupable := &Column{Name: "region", Groupable: true} + metric := &Column{Name: "total"} + index := Columns{groupable, metric}.Index(text.CaseFormatLowerUnderscore) + index.RegisterWithName("Region", groupable) + + aView := &View{ + Columns: []*Column{groupable, metric}, + _columns: index, + } + + require.True(t, aView.IsGroupable("region")) + require.True(t, aView.IsGroupable("Region")) + require.False(t, aView.IsGroupable("total")) + require.False(t, aView.IsGroupable("missing")) +} + +func TestView_inherit_Groupable(t *testing.T) { + child := &View{} + parent := &View{Groupable: true} + + err := child.inherit(parent) + require.NoError(t, err) + require.True(t, child.Groupable) +} diff --git a/view/grouped_relation_compat_test.go b/view/grouped_relation_compat_test.go new file mode 100644 index 000000000..2d1cbfee3 --- /dev/null +++ b/view/grouped_relation_compat_test.go @@ -0,0 +1,56 @@ +package view + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view/state" + "github.com/viant/tagly/format/text" +) + +func TestView_EnsureColumns_UsesTypedSchemaForGroupedRelationAlias(t *testing.T) { + ctx := context.Background() + resource := NewResource(nil) + resource.Types = []*TypeDefinition{ + { + Name: "DisqualifiedView", + Package: "taxonomy", + ModulePath: "github.vianttech.com/viant/platform/pkg/platform/taxonomy", + DataType: `struct{TaxonomyId string ` + "`sqlx:\"TAXONOMY_ID\" source:\"SEGMENT_ID\" velty:\"names=TAXONOMY_ID|TaxonomyId\"`" + `; IsDisqualified int ` + "`sqlx:\"IS_DISQUALIFIED\" internal:\"true\" json:\"-\" velty:\"names=IS_DISQUALIFIED|IsDisqualified\"`" + `; }`, + }, + } + require.NoError(t, resource.Init(ctx)) + + aView := &View{ + Name: "disqualified", + Table: "CI_TAXONOMY_DISQUALIFIED", + Alias: "t", + Mode: ModeQuery, + Schema: &state.Schema{Name: "DisqualifiedView", Package: "taxonomy", Cardinality: state.Many}, + Template: &Template{Source: "SELECT dq.SEGMENT_ID AS TAXONOMY_ID, 1 AS IS_DISQUALIFIED FROM CI_TAXONOMY_DISQUALIFIED dq GROUP BY dq.SEGMENT_ID"}, + ColumnsConfig: map[string]*ColumnConfig{ + "IS_DISQUALIFIED": { + Name: "IS_DISQUALIFIED", + Tag: ptrString(`json:"-" internal:"true"`), + }, + }, + } + + require.NoError(t, aView.ensureColumns(ctx, resource)) + require.Len(t, aView.Columns, 2) + aView.CaseFormat = text.CaseFormatLowerUnderscore + aView._columns = Columns(aView.Columns).Index(aView.CaseFormat) + + column, ok := aView.ColumnByName("TaxonomyId") + require.True(t, ok) + require.Equal(t, "TAXONOMY_ID", column.Name) + + column, ok = aView.ColumnByName("SEGMENT_ID") + require.True(t, ok) + require.Equal(t, "TAXONOMY_ID", column.Name) +} + +func ptrString(value string) *string { + return &value +} diff --git a/view/option.go b/view/option.go index fb776b94d..d630f7145 100644 --- a/view/option.go +++ b/view/option.go @@ -50,6 +50,56 @@ func WithColumns(columns Columns) Option { } } +func WithGroupable(groupable bool) Option { + return func(v *View) error { + v.Groupable = groupable + return nil + } +} + +func WithSummary(summary *TemplateSummary) Option { + return func(v *View) error { + v.EnsureTemplate() + v.Template.Summary = summary + return nil + } +} + +func WithSummaryURI(sourceURL string) Option { + return func(v *View) error { + v.EnsureTemplate() + if v.Template.Summary == nil { + v.Template.Summary = &TemplateSummary{} + } + v.Template.Summary.SourceURL = sourceURL + return nil + } +} + +func WithTemplateParameterStateType(enabled bool) Option { + return func(v *View) error { + v.EnsureTemplate() + v.Template.UseParameterStateType = enabled + return nil + } +} + +func WithDeclaredTemplateParametersOnly(enabled bool) Option { + return func(v *View) error { + v.EnsureTemplate() + v.Template.DeclaredParametersOnly = enabled + return nil + } +} + +func WithResourceParameterLookup(enabled bool) Option { + return func(v *View) error { + v.EnsureTemplate() + v.Template.UseResourceParameterLookup = enabled + return nil + } +} + // WithFS creates fs options func WithFS(fs *embed.FS) Option { return func(v *View) error { diff --git a/view/option_test.go b/view/option_test.go new file mode 100644 index 000000000..83d802dc4 --- /dev/null +++ b/view/option_test.go @@ -0,0 +1,83 @@ +package view + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view/state" +) + +func TestWithSummary(t *testing.T) { + aView := NewView("vendor", "") + err := WithSummary(&TemplateSummary{ + Name: "Meta", + SourceURL: "vendor/vendor_summary.sql", + })(aView) + require.NoError(t, err) + require.NotNil(t, aView.Template) + require.NotNil(t, aView.Template.Summary) + require.Equal(t, "Meta", aView.Template.Summary.Name) + require.Equal(t, "vendor/vendor_summary.sql", aView.Template.Summary.SourceURL) +} + +func TestWithSummaryURI(t *testing.T) { + aView := NewView("vendor", "") + err := WithSummaryURI("vendor/vendor_summary.sql")(aView) + require.NoError(t, err) + require.NotNil(t, aView.Template) + require.NotNil(t, aView.Template.Summary) + require.Equal(t, "vendor/vendor_summary.sql", aView.Template.Summary.SourceURL) +} + +func TestWithTemplateParameterStateType(t *testing.T) { + type input struct { + Foos *struct { + ID int + } + } + + resource := EmptyResource() + aView := &View{ + Name: "foos", + Table: "FOOS", + Schema: state.NewSchema(reflect.TypeOf(&input{})), + _resource: resource, + } + aView.Template = NewTemplate( + `$CurFoosId.Values`, + WithTemplateParameters( + &state.Parameter{ + Name: "Foos", + In: state.NewBodyLocation(""), + Schema: state.NewSchema(reflect.TypeOf(&struct{ ID int }{})), + Tag: `anonymous:"true"`, + }, + &state.Parameter{ + Name: "CurFoosId", + In: state.NewParameterLocation("Foos"), + Schema: state.NewSchema(reflect.TypeOf(&struct{ Values []int }{})), + }, + ), + ) + require.NoError(t, WithTemplateParameterStateType(true)(aView)) + require.NoError(t, aView.Template.Init(context.Background(), resource, aView)) + require.NotNil(t, aView.Template.StateType()) + require.NotNil(t, aView.Template.StateType().Lookup("Foos")) + require.NotNil(t, aView.Template.StateType().Lookup("CurFoosId")) +} + +func TestWithDeclaredTemplateParametersOnly(t *testing.T) { + aView := NewView("vendor", "") + require.NoError(t, WithDeclaredTemplateParametersOnly(true)(aView)) + require.NotNil(t, aView.Template) + require.True(t, aView.Template.DeclaredParametersOnly) +} + +func TestWithResourceParameterLookup(t *testing.T) { + aView := NewView("vendor", "") + require.NoError(t, WithResourceParameterLookup(true)(aView)) + require.NotNil(t, aView.Template) + require.True(t, aView.Template.UseResourceParameterLookup) +} diff --git a/view/state/kind/locator/data.go b/view/state/kind/locator/data.go index db35876c5..ec5c6d135 100644 --- a/view/state/kind/locator/data.go +++ b/view/state/kind/locator/data.go @@ -6,6 +6,7 @@ import ( "github.com/viant/datly/view" "github.com/viant/datly/view/state" "github.com/viant/datly/view/state/kind" + "os" "reflect" ) @@ -18,19 +19,43 @@ func (p *DataView) Names() []string { return nil } -func (p *DataView) Value(ctx context.Context, _ reflect.Type, name string) (interface{}, bool, error) { +func (p *DataView) Value(ctx context.Context, rType reflect.Type, name string) (interface{}, bool, error) { aView, ok := p.Views[name] if !ok { return nil, false, fmt.Errorf("failed to lookup view: %v", name) } + if os.Getenv("DATLY_DEBUG_VIEW_LOCATOR") == "1" { + fmt.Printf("[VIEW LOCATOR] name=%s schema=%v card=%s slice=%v\n", name, func() reflect.Type { + if aView.Schema == nil { + return nil + } + return aView.Schema.Type() + }(), func() state.Cardinality { + if aView.Schema == nil { + return "" + } + return aView.Schema.Cardinality + }(), func() reflect.Type { + if aView.Schema == nil { + return nil + } + return aView.Schema.SliceType() + }()) + } sliceValue := reflect.New(aView.Schema.SliceType()) destSlicePtr := sliceValue.Interface() err := p.ReadInto(ctx, destSlicePtr, aView) if err != nil { + if os.Getenv("DATLY_DEBUG_VIEW_LOCATOR") == "1" { + fmt.Printf("[VIEW LOCATOR] name=%s readIntoErr=%v dest=%T\n", name, err, destSlicePtr) + } return nil, false, err } + if os.Getenv("DATLY_DEBUG_VIEW_LOCATOR") == "1" { + fmt.Printf("[VIEW LOCATOR] name=%s len=%d dest=%T\n", name, sliceValue.Elem().Len(), destSlicePtr) + } - if aView.Schema.Cardinality == state.One { + if shouldReturnSingleValue(aView, rType) { switch sliceValue.Elem().Len() { case 0: return nil, true, nil @@ -43,6 +68,24 @@ func (p *DataView) Value(ctx context.Context, _ reflect.Type, name string) (inte return sliceValue.Elem().Interface(), true, err } +func shouldReturnSingleValue(aView *view.View, rType reflect.Type) bool { + if aView != nil && aView.Schema != nil && aView.Schema.Cardinality == state.One { + return true + } + if rType == nil { + return false + } + for rType.Kind() == reflect.Interface { + rType = rType.Elem() + } + switch rType.Kind() { + case reflect.Slice, reflect.Array, reflect.Map: + return false + default: + return true + } +} + func NewView(opts ...Option) (kind.Locator, error) { options := NewOptions(opts) if options.Views == nil { diff --git a/view/state/kind/locator/data_test.go b/view/state/kind/locator/data_test.go new file mode 100644 index 000000000..37ee43c14 --- /dev/null +++ b/view/state/kind/locator/data_test.go @@ -0,0 +1,62 @@ +package locator + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type dataViewLocatorRecord struct { + ID int +} + +func TestDataView_Value_UsesRequestedScalarTypeToUnwrapSingleResult(t *testing.T) { + aView := &view.View{ + Name: "CurFoos", + Schema: state.NewSchema(reflect.TypeOf([]*dataViewLocatorRecord{})), + } + aView.Schema.Cardinality = state.Many + locator := &DataView{ + Views: view.NamedViews{"CurFoos": aView}, + ReadInto: func(ctx context.Context, dest interface{}, aView *view.View) error { + target := dest.(*[]*dataViewLocatorRecord) + *target = append(*target, &dataViewLocatorRecord{ID: 7}) + return nil + }, + } + + value, ok, err := locator.Value(context.Background(), reflect.TypeOf(&dataViewLocatorRecord{}), "CurFoos") + require.NoError(t, err) + require.True(t, ok) + record, ok := value.(*dataViewLocatorRecord) + require.True(t, ok) + require.Equal(t, 7, record.ID) +} + +func TestDataView_Value_PreservesSliceForSliceTarget(t *testing.T) { + aView := &view.View{ + Name: "CurFoos", + Schema: state.NewSchema(reflect.TypeOf([]*dataViewLocatorRecord{})), + } + aView.Schema.Cardinality = state.Many + locator := &DataView{ + Views: view.NamedViews{"CurFoos": aView}, + ReadInto: func(ctx context.Context, dest interface{}, aView *view.View) error { + target := dest.(*[]*dataViewLocatorRecord) + *target = append(*target, &dataViewLocatorRecord{ID: 7}) + return nil + }, + } + + value, ok, err := locator.Value(context.Background(), reflect.TypeOf([]*dataViewLocatorRecord{}), "CurFoos") + require.NoError(t, err) + require.True(t, ok) + records, ok := value.([]*dataViewLocatorRecord) + require.True(t, ok) + require.Len(t, records, 1) + require.Equal(t, 7, records[0].ID) +} diff --git a/view/state/parameter.go b/view/state/parameter.go index a1ded9294..8bee637d6 100644 --- a/view/state/parameter.go +++ b/view/state/parameter.go @@ -53,6 +53,7 @@ type ( URI string `json:",omitempty" yaml:"URI"` Cacheable *bool `json:",omitempty" yaml:"Cacheable"` Async bool `json:",omitempty" yaml:"Async"` + PreserveSchema bool `json:",omitempty" yaml:"PreserveSchema"` isOutputType bool _timeLayout string _selector *structology.Selector @@ -546,6 +547,11 @@ func (p *Parameter) OutputType() reflect.Type { } func (p *Parameter) initParamBasedParameter(ctx context.Context, resource Resource) error { + if p.Schema != nil { + if p.Schema.Type() != nil { + return nil + } + } if p.Schema.Type() != nil { return nil } diff --git a/view/state/parameter_test.go b/view/state/parameter_test.go new file mode 100644 index 000000000..41f540646 --- /dev/null +++ b/view/state/parameter_test.go @@ -0,0 +1,97 @@ +package state + +import ( + "context" + "embed" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/xdatly/codec" + "github.com/viant/xreflect" +) + +type parameterNamedPatchFoos struct { + ID int +} + +type testResource struct { + params map[string]*Parameter +} + +func (t *testResource) LookupParameter(name string) (*Parameter, error) { return t.params[name], nil } +func (t *testResource) AppendParameter(parameter *Parameter) {} +func (t *testResource) ViewSchema(context.Context, string) (*Schema, error) { + return nil, nil +} +func (t *testResource) ViewSchemaPointer(context.Context, string) (*Schema, error) { + return nil, nil +} +func (t *testResource) LookupType() xreflect.LookupType { return nil } +func (t *testResource) LoadText(context.Context, string) (string, error) { + return "", nil +} +func (t *testResource) Codecs() *codec.Registry { return nil } +func (t *testResource) CodecOptions() *codec.Options { return nil } +func (t *testResource) ExpandSubstitutes(text string) string { return text } +func (t *testResource) ReverseSubstitutes(text string) string { return text } +func (t *testResource) EmbedFS() *embed.FS { return nil } +func (t *testResource) SetFSEmbedder(*FSEmbedder) {} + +func TestParameter_initParamBasedParameter_ResolvesSourceSchemaEvenWithExplicitDataType(t *testing.T) { + resource := &testResource{ + params: map[string]*Parameter{ + "Foos": { + Name: "Foos", + In: NewBodyLocation(""), + Schema: NewSchema(reflect.TypeOf(&struct{ ID int }{})), + }, + }, + } + param := &Parameter{ + Name: "CurFoosId", + In: NewParameterLocation("Foos"), + Schema: &Schema{DataType: `*struct { Values []int "json:\",omitempty\"" }`}, + PreserveSchema: true, + } + + require.NoError(t, param.initParamBasedParameter(context.Background(), resource)) + require.NotNil(t, param.Schema) + require.Equal(t, reflect.TypeOf(&struct{ ID int }{}), param.Schema.Type()) +} + +func TestParameters_ReflectType_QualifiedNamedDataTypeResolves(t *testing.T) { + registry := xreflect.NewTypes() + require.NoError(t, registry.Register("FoosView", xreflect.WithPackage("patch_basic_one"), xreflect.WithReflectType(reflect.TypeOf(parameterNamedPatchFoos{})))) + + params := Parameters{ + &Parameter{ + Name: "Foos", + In: NewBodyLocation(""), + Schema: &Schema{Name: "FoosView", Package: "patch_basic_one", DataType: "*patch_basic_one.FoosView", Cardinality: One}, + }, + } + + rType, err := params.ReflectType("patch_basic_one", registry.Lookup) + require.NoError(t, err) + field, ok := rType.FieldByName("Foos") + require.True(t, ok) + require.Equal(t, reflect.TypeOf(¶meterNamedPatchFoos{}), field.Type) +} + +func TestParameter_buildTag_ParamDoesNotOverrideSourceDataType(t *testing.T) { + param := &Parameter{ + Name: "CurFoosId", + In: NewParameterLocation("Foos"), + Schema: &Schema{Name: "CurFoosId", DataType: `*struct { Values []int "json:\",omitempty\"" }`}, + Output: &Codec{ + Name: "structql", + Schema: &Schema{Name: "CurFoosId", DataType: `*struct { Values []int "json:\",omitempty\"" }`}, + }, + } + + tag := string(param.buildTag("CurFoosId")) + require.NotContains(t, tag, `dataType:"`) + require.Contains(t, tag, `kind=param`) + require.Contains(t, tag, `in=Foos`) +} diff --git a/view/state/parameters.go b/view/state/parameters.go index e464ce142..24235aab7 100644 --- a/view/state/parameters.go +++ b/view/state/parameters.go @@ -604,7 +604,9 @@ func (p *Parameter) buildTag(fieldName string) reflect.StructTag { } if p.Output != nil && p.Output.Schema != nil { if p.Output.Schema.TypeName() != p.Schema.TypeName() { - aTag.Parameter.DataType = p.Schema.TypeName() + if p.In == nil || p.In.Kind != KindParam { + aTag.Parameter.DataType = p.Schema.TypeName() + } } } if p.Handler != nil { diff --git a/view/tags/query_selector.go b/view/tags/query_selector.go new file mode 100644 index 000000000..30c61f552 --- /dev/null +++ b/view/tags/query_selector.go @@ -0,0 +1,23 @@ +package tags + +import "strings" + +const QuerySelectorTag = "querySelector" + +// ParseQuerySelector returns the target view alias encoded in querySelector tag. +// Supported forms: +// +// querySelector:"vendor" +// querySelector:"view=vendor" +func ParseQuerySelector(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if key, mapped, ok := strings.Cut(value, "="); ok { + if strings.EqualFold(strings.TrimSpace(key), "view") { + return strings.TrimSpace(mapped) + } + } + return value +} diff --git a/view/tags/view.go b/view/tags/view.go index 3e96b4f5c..0002d3534 100644 --- a/view/tags/view.go +++ b/view/tags/view.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/viant/afs/storage" "github.com/viant/tagly/tags" + "sort" "strconv" "strings" ) @@ -19,6 +20,9 @@ type ( View struct { Name string Table string + SummaryURI string + TypeName string + Dest string CustomTag string Parameters []string //parameter references Connector string @@ -30,6 +34,15 @@ type ( PartitionerType string PartitionedConcurrency int RelationalConcurrency int + Groupable *bool + SelectorNamespace string + SelectorCriteria *bool + SelectorProjection *bool + SelectorOrderBy *bool + SelectorOffset *bool + SelectorPage *bool + SelectorFilterable []string + SelectorOrderByColumns map[string]string } ) @@ -60,6 +73,12 @@ func (t *Tag) updateView(key string, value string) error { tag.Limit = &limit case "table": tag.Table = strings.TrimSpace(value) + case "summaryuri": + tag.SummaryURI = strings.TrimSpace(value) + case "type": + tag.TypeName = strings.TrimSpace(value) + case "dest": + tag.Dest = strings.TrimSpace(value) case "connector": tag.Connector = strings.TrimSpace(value) case "partitioner": @@ -81,6 +100,24 @@ func (t *Tag) updateView(key string, value string) error { for _, parameter := range strings.Split(parameters, ",") { tag.Parameters = append(tag.Parameters, strings.TrimSpace(parameter)) } + case "groupable": + tag.Groupable = parseBoolPointer(value) + case "selectornamespace": + tag.SelectorNamespace = strings.TrimSpace(value) + case "selectorcriteria": + tag.SelectorCriteria = parseBoolPointer(value) + case "selectorprojection": + tag.SelectorProjection = parseBoolPointer(value) + case "selectororderby": + tag.SelectorOrderBy = parseBoolPointer(value) + case "selectoroffset": + tag.SelectorOffset = parseBoolPointer(value) + case "selectorpage": + tag.SelectorPage = parseBoolPointer(value) + case "selectorfilterable": + tag.SelectorFilterable = parseTagList(value) + case "selectororderbycolumns": + tag.SelectorOrderByColumns = parseTagMap(value) default: return fmt.Errorf("unsupported view tag option: '%s'", key) } @@ -105,6 +142,9 @@ func (v *View) Tag() *tags.Tag { appendNonEmpty(builder, "limit", strconv.Itoa(*v.Limit)) } appendNonEmpty(builder, "table", v.Table) + appendNonEmpty(builder, "summaryURI", v.SummaryURI) + appendNonEmpty(builder, "type", v.TypeName) + appendNonEmpty(builder, "dest", v.Dest) if v.Batch > 0 { appendNonEmpty(builder, "batch", strconv.Itoa(v.Batch)) } @@ -126,6 +166,28 @@ func (v *View) Tag() *tags.Tag { appendNonEmpty(builder, "concurrency", strconv.Itoa(v.PartitionedConcurrency)) } } + appendBool(builder, "groupable", v.Groupable) + appendNonEmpty(builder, "selectorNamespace", v.SelectorNamespace) + appendBool(builder, "selectorCriteria", v.SelectorCriteria) + appendBool(builder, "selectorProjection", v.SelectorProjection) + appendBool(builder, "selectorOrderBy", v.SelectorOrderBy) + appendBool(builder, "selectorOffset", v.SelectorOffset) + appendBool(builder, "selectorPage", v.SelectorPage) + if len(v.SelectorFilterable) > 0 { + appendNonEmpty(builder, "selectorFilterable", "{"+strings.Join(v.SelectorFilterable, ",")+"}") + } + if len(v.SelectorOrderByColumns) > 0 { + keys := make([]string, 0, len(v.SelectorOrderByColumns)) + for key := range v.SelectorOrderByColumns { + keys = append(keys, key) + } + sort.Strings(keys) + pairs := make([]string, 0, len(keys)) + for _, key := range keys { + pairs = append(pairs, key+":"+v.SelectorOrderByColumns[key]) + } + appendNonEmpty(builder, "selectorOrderByColumns", "{"+strings.Join(pairs, ",")+"}") + } return &tags.Tag{Name: ViewTag, Values: tags.Values(builder.String())} } @@ -138,3 +200,64 @@ func appendNonEmpty(builder *strings.Builder, key, value string) { builder.WriteString("=") builder.WriteString(value) } + +func appendBool(builder *strings.Builder, key string, value *bool) { + if value == nil { + return + } + appendNonEmpty(builder, key, strconv.FormatBool(*value)) +} + +func parseBoolPointer(value string) *bool { + v := true + switch strings.ToLower(strings.TrimSpace(value)) { + case "", "true", "1": + v = true + case "false", "0": + v = false + } + return &v +} + +func parseTagList(value string) []string { + value = strings.Trim(value, "{}'\" ") + if value == "" { + return nil + } + items := strings.Split(value, ",") + ret := make([]string, 0, len(items)) + for _, item := range items { + item = strings.TrimSpace(item) + if item != "" { + ret = append(ret, item) + } + } + return ret +} + +func parseTagMap(value string) map[string]string { + value = strings.Trim(value, "{}'\" ") + if value == "" { + return nil + } + ret := map[string]string{} + for _, item := range strings.Split(value, ",") { + item = strings.TrimSpace(item) + if item == "" { + continue + } + key, mapped, ok := strings.Cut(item, ":") + if !ok { + continue + } + key = strings.TrimSpace(key) + mapped = strings.TrimSpace(mapped) + if key != "" && mapped != "" { + ret[key] = mapped + } + } + if len(ret) == 0 { + return nil + } + return ret +} diff --git a/view/tags/view_test.go b/view/tags/view_test.go index 1127cf66c..e84ded4c1 100644 --- a/view/tags/view_test.go +++ b/view/tags/view_test.go @@ -37,6 +37,29 @@ func TestTag_updateView(t *testing.T) { tag: `view:"foo,table=FOO,connector=dev,parameters={P1,P2}"`, expectView: &View{Name: "foo", Table: "FOO", Connector: "dev", Parameters: []string{"P1", "P2"}}, }, + { + description: "selector metadata view", + tag: `view:"foo,table=FOO,groupable=true,selectorNamespace=ve,selectorCriteria=true,selectorProjection=true,selectorOrderBy=true,selectorOffset=true,selectorFilterable={*},selectorOrderByColumns={accountId:ACCOUNT_ID,userCreated:USER_CREATED}"`, + expectView: &View{ + Name: "foo", + Table: "FOO", + Groupable: boolPtr(true), + SelectorNamespace: "ve", + SelectorCriteria: boolPtr(true), + SelectorProjection: boolPtr(true), + SelectorOrderBy: boolPtr(true), + SelectorOffset: boolPtr(true), + SelectorFilterable: []string{"*"}, + SelectorOrderByColumns: map[string]string{"accountId": "ACCOUNT_ID", "userCreated": "USER_CREATED"}, + }, + expectTag: "foo,table=FOO,groupable=true,selectorNamespace=ve,selectorCriteria=true,selectorProjection=true,selectorOrderBy=true,selectorOffset=true,selectorFilterable={*},selectorOrderByColumns={accountId:ACCOUNT_ID,userCreated:USER_CREATED}", + }, + { + description: "summary uri view", + tag: `view:"foo,table=FOO,summaryURI=testdata/foo_summary.sql"`, + expectView: &View{Name: "foo", Table: "FOO", SummaryURI: "testdata/foo_summary.sql"}, + expectTag: "foo,table=FOO,summaryURI=testdata/foo_summary.sql", + }, } for _, testCase := range testCases { @@ -55,3 +78,7 @@ func TestTag_updateView(t *testing.T) { assert.EqualValues(t, expectTag, string(actual.View.Tag().Values), testCase.description) } } + +func boolPtr(v bool) *bool { + return &v +} diff --git a/view/template.go b/view/template.go index ae4308723..abf549165 100644 --- a/view/template.go +++ b/view/template.go @@ -24,6 +24,15 @@ type ( Source string `json:",omitempty" yaml:"source,omitempty"` SourceURL string `json:",omitempty" yaml:"sourceURL,omitempty"` Schema *state.Schema `json:",omitempty" yaml:"schema,omitempty"` + // UseParameterStateType makes Velty compile against template parameters + // instead of the view schema when helper state exists outside the named IO type. + UseParameterStateType bool `json:",omitempty" yaml:"useParameterStateType,omitempty"` + // DeclaredParametersOnly prevents global resource parameter binding from + // appending undeclared parameters to this template. + DeclaredParametersOnly bool `json:",omitempty" yaml:"declaredParametersOnly,omitempty"` + // UseResourceParameterLookup allows param/state source lookup to resolve + // against resource parameters in addition to the declared template params. + UseResourceParameterLookup bool `json:",omitempty" yaml:"useResourceParameterLookup,omitempty"` stateType *structology.StateType @@ -92,7 +101,16 @@ func (t *Template) Init(ctx context.Context, resource *Resource, view *View) err if err = t.initTypes(ctx, resource); err != nil { return err } - if rType := t.Schema.Type(); rType != nil { + if t.UseParameterStateType && len(t.Parameters) > 0 { + rType, err := t.Parameters.ReflectType(t.Package(), resource.LookupType(), state.WithSetMarker()) + if err != nil { + return fmt.Errorf("failed to build template parameter state for %s: %w", t._view.Name, err) + } + if rType.Kind() == reflect.Struct { + rType = reflect.PtrTo(rType) + } + t.stateType = structology.NewStateType(rType) + } else if rType := t.Schema.Type(); rType != nil { t.stateType = structology.NewStateType(rType) } @@ -269,6 +287,30 @@ func WithTemplateParameters(parameters ...*state.Parameter) TemplateOption { } } +// WithTemplateUnsafeStateFromParameters configures template evaluation to derive +// the Velty Unsafe state from template parameters rather than the named view schema. +func WithTemplateUnsafeStateFromParameters(enabled bool) TemplateOption { + return func(t *Template) { + t.UseParameterStateType = enabled + } +} + +// WithTemplateDeclaredParametersOnly preserves only explicitly declared +// template parameters during later resource binding. +func WithTemplateDeclaredParametersOnly(enabled bool) TemplateOption { + return func(t *Template) { + t.DeclaredParametersOnly = enabled + } +} + +// WithTemplateResourceParameterLookup allows template parameter source lookup +// to resolve from resource parameters while keeping the declared template state minimal. +func WithTemplateResourceParameterLookup(enabled bool) TemplateOption { + return func(t *Template) { + t.UseResourceParameterLookup = enabled + } +} + // WithTemplateSchema returns with template schema func WithTemplateSchema(schema *state.Schema) TemplateOption { return func(t *Template) { diff --git a/view/view.go b/view/view.go index b297489f1..f09073e6b 100644 --- a/view/view.go +++ b/view/view.go @@ -64,6 +64,7 @@ type ( PublishParent bool `json:",omitempty"` Partitioned *Partitioned Criteria string `json:",omitempty"` + Groupable bool `json:",omitempty"` Selector *Config `json:",omitempty"` Template *Template `json:",omitempty"` @@ -461,6 +462,9 @@ func (v *View) buildViewOptions(aViewType reflect.Type, tag *tags.Tag) ([]Option for _, name := range vTag.Parameters { parameters = append(parameters, state.NewRefParameter(name)) } + if vTag.SummaryURI != "" { + options = append(options, WithSummaryURI(vTag.SummaryURI)) + } } if SQL := tag.SQL; SQL.SQL != "" { tmpl := NewTemplate(string(SQL.SQL), WithTemplateParameters(parameters...)) @@ -936,6 +940,11 @@ func (v *View) ensureColumns(ctx context.Context, resource *Resource) error { if len(v.Columns) != 0 { return nil } + if v.Schema != nil { + if err := v.Schema.LoadTypeIfNeeded(resource.LookupType()); err != nil { + return err + } + } //if scheme type defines sqlx tag, use it as source for column instead of detection if rType := v.Schema.Type(); rType != nil { sType := types.EnsureStruct(rType) @@ -1032,7 +1041,10 @@ func convertIoColumnsToColumns(ioColumns []io.Column, nullable map[string]bool) // ColumnByName returns Column by Column.Name func (v *View) ColumnByName(name string) (*Column, bool) { - if column, ok := v._columns[name]; ok { + if v == nil || v._columns == nil { + return nil, false + } + if column, err := v._columns.Lookup(name); err == nil { return column, true } @@ -1098,6 +1110,7 @@ func (v *View) inherit(view *View) error { setter.SetStringIfEmpty(&v.Module, view.Module) setter.SetStringIfEmpty(&v.Tag, view.Tag) setter.SetBoolIfFalse(&v.PublishParent, view.PublishParent) + setter.SetBoolIfFalse(&v.Groupable, view.Groupable) setter.SetStringIfEmpty(&v.Description, view.Description) @@ -1286,6 +1299,18 @@ func (v *View) IndexedColumns() NamedColumns { return v._columns } +// IsGroupable reports whether the supplied field or column name resolves to a groupable column. +func (v *View) IsGroupable(name string) bool { + if v == nil || len(v._columns) == 0 { + return false + } + column, err := v._columns.Lookup(name) + if err != nil { + return false + } + return column.Groupable +} + func (v *View) markColumnsAsFilterable() error { if len(v.Selector.Constraints.Filterable) == 1 && strings.TrimSpace(v.Selector.Constraints.Filterable[0]) == "*" { for _, column := range v.Columns { diff --git a/view/views.go b/view/views.go index 494b30d2a..9fcc27123 100644 --- a/view/views.go +++ b/view/views.go @@ -72,7 +72,7 @@ func (n *NamespacedView) indexView(aView *View, aPath string) { nsView.Root = true nsView.Namespaces = append(nsView.Namespaces, "") } - if selector.Namespace != "" { + if selector != nil && selector.Namespace != "" { nsView.Namespaces = append(nsView.Namespaces, selector.Namespace) } n.Views = append(n.Views, nsView)