Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions cmd/pki/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ package create

import (
"context"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"os"

aeCMD "github.com/aurae-runtime/ae/cmd"
"github.com/aurae-runtime/ae/pkg/cli"
Expand All @@ -48,6 +51,11 @@ type option struct {
directory string
domain string
user string
caPath string
caKeyPath string
csrPath string
csr *pki.CertificateRequest
ca *pki.Certificate
silent bool
writer io.Writer
}
Expand All @@ -60,17 +68,97 @@ func (o *option) Complete(args []string) error {
return fmt.Errorf("too many arguments for command 'create', expect %d, got %d", 1, len(args))
}

if o.caPath != "" {
b, err := os.ReadFile(o.caPath)
if err != nil {
return fmt.Errorf("failed to read ca certificate: %w", err)
}

o.ca = &pki.Certificate{}
o.ca.Certificate = string(b)

if o.caKeyPath != "" {
b, err := os.ReadFile(o.caKeyPath)
if err != nil {
return fmt.Errorf("failed to read ca private key: %w", err)
}

o.ca.PrivateKey = string(b)
} else {
return fmt.Errorf("must provide --caKey and --csr when using --ca")
}

if o.csrPath != "" {
b, err := os.ReadFile(o.csrPath)
if err != nil {
return fmt.Errorf("failed to read csr: %w", err)
}

o.csr = &pki.CertificateRequest{}
o.csr.CSR = string(b)
} else {
return fmt.Errorf("must provide --caKey and --csr when using --ca")
}
}

if o.csrPath != "" {
b, err := os.ReadFile(o.csrPath)
if err != nil {
return fmt.Errorf("failed to read csr: %w", err)
}

o.csr = &pki.CertificateRequest{}
o.csr.CSR = string(b)
}

o.domain = args[0]

return nil
}

func (o *option) Validate() error {
if o.caPath != "" {
caPem, _ := pem.Decode([]byte(o.ca.Certificate))
_, err := x509.ParseCertificate(caPem.Bytes)
if err != nil {
return fmt.Errorf("could not parse ca file")
}
}

if o.caKeyPath != "" {
caKeyPem, _ := pem.Decode([]byte(o.ca.PrivateKey))
_, err := x509.ParsePKCS1PrivateKey(caKeyPem.Bytes)
if err != nil {
return fmt.Errorf("could not parse key file")
}
}

if o.csrPath != "" {
csrPem, _ := pem.Decode([]byte(o.csr.CSR))
_, err := x509.ParseCertificateRequest(csrPem.Bytes)
if err != nil {
return fmt.Errorf("could not parse csr file")
}
}

return nil
}

func (o *option) Execute(_ context.Context) error {
if o.user != "" {

if o.caPath != "" {
clientCrt, err := pki.CreateClientCertificate(o.directory, o.csr.CSR, o.ca, o.user)
if err != nil {
return fmt.Errorf("failed to create client certificate: %w", err)
}
if !o.silent {
o.outputFormat.ToPrinter().Print(o.writer, &clientCrt)
}

return nil
}

clientCSR, err := pki.CreateClientCSR(o.directory, o.domain, o.user)
if err != nil {
return fmt.Errorf("failed to create client csr: %w", err)
Expand Down Expand Up @@ -118,6 +206,9 @@ ae pki create --dir ./pki/ my.domain.com`,
o.outputFormat.AddFlags(cmd)
cmd.Flags().StringVarP(&o.directory, "dir", "d", o.directory, "Output directory to store CA files.")
cmd.Flags().StringVarP(&o.user, "user", "u", o.user, "Creates client certificate for a given user.")
cmd.Flags().StringVar(&o.caPath, "ca", o.caPath, "Use the given CA certificate.")
cmd.Flags().StringVar(&o.caKeyPath, "caKey", o.caKeyPath, "The corresponding CA key.")
cmd.Flags().StringVar(&o.csrPath, "csr", o.csrPath, "CSR input file.")
cmd.Flags().BoolVarP(&o.silent, "silent", "s", o.silent, "Silent mode, omits output")

return cmd
Expand Down
140 changes: 126 additions & 14 deletions pkg/pki/pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
package pki

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
Expand All @@ -55,10 +56,10 @@ type CertificateRequest struct {
User string `json:"user" yaml:"user"`
}

func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
func createCA(domainName string) ([]byte, *rsa.PrivateKey, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate private key: %w", err)
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}

subj := pkix.Name{
Expand All @@ -75,7 +76,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate serial number: %w", err)
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
}

template := x509.Certificate{
Expand Down Expand Up @@ -107,22 +108,31 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {

crtBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to create certificate: %w", err)
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
}

crtPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: crtBytes,
})
return crtBytes, priv, nil
}

keyPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})
func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
crtBytes, priv, err := createCA(domainName)
if err != nil {
return nil, err
}

crtPem, err := getPemBuffer(crtBytes, "CERTIFICATE")
if err != nil {
return nil, err
}

keyPem, err := getPemBuffer(x509.MarshalPKCS1PrivateKey(priv), "RSA PRIVATE KEY")
if err != nil {
return nil, err
}

ca := &Certificate{
Certificate: string(crtPem),
PrivateKey: string(keyPem),
Certificate: crtPem.String(),
PrivateKey: keyPem.String(),
}

if path != "" {
Expand Down Expand Up @@ -187,6 +197,108 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
return csr, nil
}

func CreateClientCertificate(path, csrStr string, ca *Certificate, user string) (*Certificate, error) {
csrPem, _ := pem.Decode([]byte(csrStr))
if csrPem == nil || csrPem.Type != "CERTIFICATE REQUEST" {
return &Certificate{}, fmt.Errorf("failed to decode certificate request")
}

csr, err := x509.ParseCertificateRequest(csrPem.Bytes)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to parse certificate request: %w", err)
}

caCrtPem, _ := pem.Decode([]byte(ca.Certificate))
if caCrtPem == nil || caCrtPem.Type != "CERTIFICATE" {
return &Certificate{}, fmt.Errorf("failed to decode certificate")
}

caCrt, err := x509.ParseCertificate(caCrtPem.Bytes)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to parse certificate: %w", err)
}

caPrivPem, _ := pem.Decode([]byte(ca.PrivateKey))
if caPrivPem == nil || caPrivPem.Type != "RSA PRIVATE KEY" {
return &Certificate{}, fmt.Errorf("failed to decode private key")
}

caPriv, err := x509.ParsePKCS1PrivateKey(caPrivPem.Bytes)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to parse private key: %w", err)
}

now := time.Now()

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate serial number: %w", err)
}

template := x509.Certificate{
Subject: csr.Subject,
NotBefore: now,
NotAfter: now.Add(24 * time.Hour * 9999),
SerialNumber: serialNumber,
IsCA: false,
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageContentCommitment | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
}

crtPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate private key: %w", err)
}

// TODO: is this the correct Subject Key Identifier?
pubHash := sha1.Sum(crtPrivKey.PublicKey.N.Bytes())
template.SubjectKeyId = pubHash[:]

crtPrivPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(crtPrivKey),
})

clientCrtBytes, err := x509.CreateCertificate(rand.Reader, &template, caCrt, &crtPrivKey.PublicKey, caPriv)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to create certificate: %w", err)
}

clientCrtPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: clientCrtBytes,
})

clientCert := &Certificate{
Certificate: string(clientCrtPem),
PrivateKey: string(crtPrivPem),
}

// if path != "" {
// err = createCsrFiles(path, csr)
// if err != nil {
// return csr, err
// }
// }

return clientCert, nil
}

func getPemBuffer(b []byte, t string) (*bytes.Buffer, error) {
// var certBytes []byte
pemBuffer := bytes.NewBuffer([]byte{})
err := pem.Encode(pemBuffer, &pem.Block{
Type: t,
Bytes: b,
})
if err != nil {
return &bytes.Buffer{}, fmt.Errorf("failed to write \"%s\" pem buffer of type: %w", t, err)
}

return pemBuffer, nil
}

func createCAFiles(path string, ca *Certificate) error {
path = filepath.Clean(path)
err := os.MkdirAll(path, os.ModePerm)
Expand Down