Skip to content
Open
11 changes: 11 additions & 0 deletions internal/api/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Package api is intended to be the future public API for sqlc.
//
// The shape of this package is inspired by esbuild's Build API
// (https://pkg.go.dev/github.com/evanw/esbuild/pkg/api#hdr-Build_API): a small
// surface area of options structs and result structs that lets callers drive
// sqlc programmatically without going through the CLI.
//
// Today the package lives under internal/ while the API stabilises. Once the
// surface settles it is expected to graduate to pkg/api so it can be imported
// by external Go programs.
package api
108 changes: 108 additions & 0 deletions internal/api/codegen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package api

import (
"context"
"encoding/json"
"fmt"
"runtime/trace"

"google.golang.org/grpc"

"github.com/sqlc-dev/sqlc/internal/codegen/golang"
genjson "github.com/sqlc-dev/sqlc/internal/codegen/json"
"github.com/sqlc-dev/sqlc/internal/compiler"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/config/convert"
"github.com/sqlc-dev/sqlc/internal/ext"
"github.com/sqlc-dev/sqlc/internal/ext/process"
"github.com/sqlc-dev/sqlc/internal/ext/wasm"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func findPlugin(conf config.Config, name string) (*config.Plugin, error) {
for _, plug := range conf.Plugins {
if plug.Name == name {
return &plug, nil
}
}
return nil, fmt.Errorf("plugin not found")
}

func codegen(ctx context.Context, combo config.CombinedSettings, sql outputPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) {
defer trace.StartRegion(ctx, "codegen").End()
req := codeGenRequest(result, combo)
var handler grpc.ClientConnInterface
var out string
switch {
case sql.Plugin != nil:
out = sql.Plugin.Out
plug, err := findPlugin(combo.Global, sql.Plugin.Plugin)
if err != nil {
return "", nil, fmt.Errorf("plugin not found: %s", err)
}

switch {
case plug.Process != nil:
handler = &process.Runner{
Cmd: plug.Process.Cmd,
Env: plug.Env,
Format: plug.Process.Format,
}
case plug.WASM != nil:
handler = &wasm.Runner{
URL: plug.WASM.URL,
SHA256: plug.WASM.SHA256,
Env: plug.Env,
}
default:
return "", nil, fmt.Errorf("unsupported plugin type")
}

opts, err := convert.YAMLtoJSON(sql.Plugin.Options)
if err != nil {
return "", nil, fmt.Errorf("invalid plugin options: %w", err)
}
req.PluginOptions = opts

global, found := combo.Global.Options[plug.Name]
if found {
opts, err := convert.YAMLtoJSON(global)
if err != nil {
return "", nil, fmt.Errorf("invalid global options: %w", err)
}
req.GlobalOptions = opts
}

case sql.Gen.Go != nil:
out = combo.Go.Out
handler = ext.HandleFunc(golang.Generate)
opts, err := json.Marshal(sql.Gen.Go)
if err != nil {
return "", nil, fmt.Errorf("opts marshal failed: %w", err)
}
req.PluginOptions = opts

if combo.Global.Overrides.Go != nil {
opts, err := json.Marshal(combo.Global.Overrides.Go)
if err != nil {
return "", nil, fmt.Errorf("opts marshal failed: %w", err)
}
req.GlobalOptions = opts
}

case sql.Gen.JSON != nil:
out = combo.JSON.Out
handler = ext.HandleFunc(genjson.Generate)
opts, err := json.Marshal(sql.Gen.JSON)
if err != nil {
return "", nil, fmt.Errorf("opts marshal failed: %w", err)
}
req.PluginOptions = opts

default:
return "", nil, fmt.Errorf("missing language backend")
}
client := plugin.NewCodegenServiceClient(handler)
resp, err := client.Generate(ctx, req)
return out, resp, err
}
109 changes: 109 additions & 0 deletions internal/api/diff.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package api

import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime/trace"
"sort"

"github.com/cubicdaiya/gonp"
)

func writeFiles(ctx context.Context, files map[string]string, stderr io.Writer) error {
defer trace.StartRegion(ctx, "writefiles").End()
for filename, source := range files {
if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil {
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
return err
}
if err := os.WriteFile(filename, []byte(source), 0644); err != nil {
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
return err
}
}
return nil
}

func diffFiles(ctx context.Context, baseDir string, files map[string]string, stderr io.Writer) error {
defer trace.StartRegion(ctx, "checkfiles").End()
var errored bool

if baseDir == "" {
baseDir, _ = os.Getwd()
}

keys := make([]string, 0, len(files))
for k := range files {
keys = append(keys, k)
}
sort.Strings(keys)

for _, filename := range keys {
source := files[filename]
if _, err := os.Stat(filename); errors.Is(err, os.ErrNotExist) {
errored = true
continue
}
existing, err := os.ReadFile(filename)
if err != nil {
errored = true
fmt.Fprintf(stderr, "%s: %s\n", filename, err)
continue
}
d := gonp.New(getLines(existing), getLines([]byte(source)))
d.Compose()
uniHunks := filterHunks(d.UnifiedHunks())

if len(uniHunks) > 0 {
errored = true
label := filename
if baseDir != "" {
if rel, err := filepath.Rel(baseDir, filename); err == nil {
label = "/" + rel
}
}
fmt.Fprintf(stderr, "--- a%s\n", label)
fmt.Fprintf(stderr, "+++ b%s\n", label)
d.FprintUniHunks(stderr, uniHunks)
}
}
if errored {
return errors.New("diff found")
}
return nil
}

func getLines(f []byte) []string {
fp := bytes.NewReader(f)
scanner := bufio.NewScanner(fp)
lines := make([]string, 0)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
return lines
}

func filterHunks[T gonp.Elem](uniHunks []gonp.UniHunk[T]) []gonp.UniHunk[T] {
var out []gonp.UniHunk[T]
for i, uniHunk := range uniHunks {
var changed bool
for _, e := range uniHunk.GetChanges() {
switch e.GetType() {
case gonp.SesDelete:
changed = true
case gonp.SesAdd:
changed = true
}
}
if changed {
out = append(out, uniHunks[i])
}
}
return out
}
Loading