Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/pki/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (o *option) Validate() error {

func (o *option) Execute(_ context.Context) error {
if o.user != "" {
clientCSR, err := pki.CreateClientCSR(o.directory, o.domain, o.user)
clientCSR, err := pki.HandleCreateClientCSR(o.directory, o.domain, o.user)
if err != nil {
return fmt.Errorf("failed to create client csr: %w", err)
}
Expand All @@ -82,7 +82,7 @@ func (o *option) Execute(_ context.Context) error {
return nil
}

rootCA, err := pki.CreateAuraeRootCA(o.directory, o.domain)
rootCA, err := pki.HandleCreateAuraeRootCA(o.directory, o.domain)
if err != nil {
return fmt.Errorf("failed to create aurae root ca: %w", err)
}
Expand Down
247 changes: 158 additions & 89 deletions pkg/pki/pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ import (
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"time"
)

Expand All @@ -49,16 +47,145 @@ type Certificate struct {
PrivateKey string `json:"key" yaml:"key"`
}

func (c *Certificate) GetCertificate() (*x509.Certificate, error) {
crtPem, _ := pem.Decode([]byte(c.Certificate))
if crtPem == nil || crtPem.Type != "CERTIFICATE" {
return nil, fmt.Errorf("failed to decode certificate")
}

crt, err := x509.ParseCertificate(crtPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}

return crt, nil
}

func (c *Certificate) GetCertAsString() string {
return c.Certificate
}

func (c *Certificate) GetPrivateKey() (*rsa.PrivateKey, error) {
pkPem, _ := pem.Decode([]byte(c.PrivateKey))
if pkPem == nil || pkPem.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode private key")
}

pk, err := x509.ParsePKCS1PrivateKey(pkPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

return pk, nil
}

func (c *Certificate) GetPrivateKeyAsString() string {
return c.PrivateKey
}

func (c *Certificate) WriteCertificateToFile(path, filename string) error {
err := createFile(path, filename, c.GetCertAsString())
if err != nil {
return err
}
return nil
}

func (c *Certificate) WritePrivateKeyToFile(path, filename string) error {
err := createFile(path, filename, c.GetPrivateKeyAsString())
if err != nil {
return err
}
return nil
}

type CertificateRequest struct {
CSR string `json:"csr" yaml:"csr"`
PrivateKey string `json:"key" yaml:"key"`
User string `json:"user" yaml:"user"`
}

func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
func (c *CertificateRequest) GetCsr() (*x509.CertificateRequest, error) {
csrPem, _ := pem.Decode([]byte(c.CSR))
if csrPem == nil || csrPem.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("failed to decode certificate request")
}

csr, err := x509.ParseCertificateRequest(csrPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate request: %w", err)
}

return csr, nil
}

func (c *CertificateRequest) GetCsrAsString() string {
return c.CSR
}

func (c *CertificateRequest) GetPrivateKey() (*rsa.PrivateKey, error) {
pkPem, _ := pem.Decode([]byte(c.PrivateKey))
if pkPem == nil || pkPem.Type != "RSA PRIVATE KEY" {
return nil, fmt.Errorf("failed to decode private key")
}

pk, err := x509.ParsePKCS1PrivateKey(pkPem.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}

return pk, nil
}

func (c *CertificateRequest) GetPrivateKeyAsString() string {
return c.PrivateKey
}

func (c *CertificateRequest) WriteCsrToFile(path, filename string) error {
err := createFile(path, filename, c.GetCsrAsString())
if err != nil {
return err
}
return nil
}

func (c *CertificateRequest) WritePrivateKeyToFile(path, filename string) error {
err := createFile(path, filename, c.GetPrivateKeyAsString())
if err != nil {
return err
}
return nil
}

func HandleCreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
crtPem, keyPem, err := createCA(domainName)
if err != nil {
return nil, err
}

ca := &Certificate{
Certificate: string(crtPem),
PrivateKey: string(keyPem),
}

if path != "" {
err = ca.WriteCertificateToFile(path, "ca.crt")
if err != nil {
return ca, err
}
err = ca.WritePrivateKeyToFile(path, "ca.key")
if err != nil {
return ca, err
}
}

return ca, nil
}

func createCA(domainName string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate private key: %w", err)
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}

subj := pkix.Name{
Expand All @@ -75,7 +202,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to generate serial number: %w", err)
return nil, nil, fmt.Errorf("failed to generate serial number: %w", err)
}

template := x509.Certificate{
Expand Down Expand Up @@ -107,7 +234,7 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {

crtBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return &Certificate{}, fmt.Errorf("failed to create certificate: %w", err)
return nil, nil, fmt.Errorf("failed to create certificate: %w", err)
}

crtPem := pem.EncodeToMemory(&pem.Block{
Expand All @@ -117,28 +244,42 @@ func CreateAuraeRootCA(path string, domainName string) (*Certificate, error) {

keyPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(priv),
Bytes: crtBytes,
})

ca := &Certificate{
Certificate: string(crtPem),
PrivateKey: string(keyPem),
return crtPem, keyPem, nil
}

func HandleCreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
csrPem, keyPem, err := createClientCSR(domain, user)
if err != nil {
return &CertificateRequest{}, err
}

csr := &CertificateRequest{
CSR: string(csrPem),
PrivateKey: string(keyPem),
User: user,
}

if path != "" {
err = createCAFiles(path, ca)
err = csr.WriteCsrToFile(path, fmt.Sprintf("client.%s.csr", csr.User))
if err != nil {
return ca, err
return csr, err
}
err = csr.WritePrivateKeyToFile(path, fmt.Sprintf("client.%s.key", csr.User))
if err != nil {
return csr, err
}
}

return ca, nil
return csr, nil
}

func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
func createClientCSR(domain, user string) ([]byte, []byte, error) {
priv, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return &CertificateRequest{}, fmt.Errorf("failed to generate private key: %w", err)
return []byte{}, []byte{}, fmt.Errorf("failed to generate private key: %w", err)
}

subj := pkix.Name{
Expand All @@ -158,7 +299,7 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {

csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &template, priv)
if err != nil {
return &CertificateRequest{}, fmt.Errorf("could not create certificate request: %w", err)
return []byte{}, []byte{}, fmt.Errorf("could not create certificate request: %w", err)
}

csrPem := pem.EncodeToMemory(&pem.Block{
Expand All @@ -171,77 +312,5 @@ func CreateClientCSR(path, domain, user string) (*CertificateRequest, error) {
Bytes: x509.MarshalPKCS1PrivateKey(priv),
})

csr := &CertificateRequest{
CSR: string(csrPem),
PrivateKey: string(keyPem),
User: user,
}

if path != "" {
err = createCsrFiles(path, csr)
if err != nil {
return csr, err
}
}

return csr, nil
}

func createCAFiles(path string, ca *Certificate) error {
path = filepath.Clean(path)
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}

crtPath := filepath.Join(path, "ca.crt")
keyPath := filepath.Join(path, "ca.key")

err = writeStringToFile(crtPath, ca.Certificate)
if err != nil {
return err
}

err = writeStringToFile(keyPath, ca.PrivateKey)
if err != nil {
return err
}
return nil
}

func createCsrFiles(path string, ca *CertificateRequest) error {
path = filepath.Clean(path)
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}

csrPath := filepath.Join(path, fmt.Sprintf("client.%s.csr", ca.User))
keyPath := filepath.Join(path, fmt.Sprintf("client.%s.key", ca.User))

err = writeStringToFile(csrPath, ca.CSR)
if err != nil {
return err
}

err = writeStringToFile(keyPath, ca.PrivateKey)
if err != nil {
return err
}
return nil
}

func writeStringToFile(p string, s string) error {
f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", p, err)
}
defer f.Close()

_, err = f.WriteString(s)
if err != nil {
return fmt.Errorf("failed to write file %s: %w", p, err)
}

return nil
return csrPem, keyPem, nil
}
10 changes: 5 additions & 5 deletions pkg/pki/pki_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestCreateAuraeRootCA(t *testing.T) {
t.Run("createAuraeCA", func(t *testing.T) {
domainName := "unsafe.aurae.io"

auraeCa, err := CreateAuraeRootCA("", "unsafe.aurae.io")
auraeCa, err := HandleCreateAuraeRootCA("", "unsafe.aurae.io")
if err != nil {
t.Errorf("could not create auraeCA")
}
Expand All @@ -64,7 +64,7 @@ func TestCreateAuraeRootCA(t *testing.T) {
path := "_tmp/pki"
domainName := "unsafe.aurae.io"

_, err := CreateAuraeRootCA(path, domainName)
_, err := HandleCreateAuraeRootCA(path, domainName)
if err != nil {
t.Errorf("could not create auraeCA")
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestCreateAuraeRootCA(t *testing.T) {

func TestCreateCSR(t *testing.T) {
t.Run("createCSR", func(t *testing.T) {
clientCsr, err := CreateClientCSR("", "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR("", "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could not create csr")
}
Expand All @@ -121,7 +121,7 @@ func TestCreateCSR(t *testing.T) {

t.Run("createCSR with local files", func(t *testing.T) {
path := "_tmp/pki"
clientCsr, err := CreateClientCSR(path, "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR(path, "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could not create csr")
}
Expand Down Expand Up @@ -224,7 +224,7 @@ func TestCreateCSR(t *testing.T) {
}

// Genenerate a new CSR
clientCsr, err := CreateClientCSR("", "unsafe.aurae.io", "christoph")
clientCsr, err := HandleCreateClientCSR("", "unsafe.aurae.io", "christoph")
if err != nil {
t.Errorf("could create csr")
}
Expand Down
Loading