package orm import ( "fmt" "reflect" "strings" ) type schemaSnapshotColumn struct { Document `d:"table:__schemas"` ID int64 `d:"pk"` ModelName string FieldName string FieldType string FieldIndex int IsRelationship bool IsSynthetic bool } func (m *Model) toSnapshotColumns() (ssc []*schemaSnapshotColumn) { for _, field := range m.Fields { ssc = append(ssc, &schemaSnapshotColumn{ ModelName: m.Name, FieldName: field.Name, FieldType: field.Type.String(), FieldIndex: field.Index, }) } for _, rel := range m.Relationships { rt := rel.RelatedType if rel.Kind == reflect.Slice { rt = reflect.SliceOf(rel.RelatedType) } ssc = append(ssc, &schemaSnapshotColumn{ ModelName: m.Name, FieldName: rel.FieldName, FieldType: rt.String(), FieldIndex: rel.Idx, IsRelationship: true, IsSynthetic: rel.Idx < 0, }) } return } func (m *Model) createTableSql() string { var fields []string for _, field := range m.Fields { if !field.isAnonymous() { isStructOrSliceOfStructs := field.Type.Kind() == reflect.Struct || ((field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) && field.Type.Elem().Kind() == reflect.Struct) if field.PrimaryKey { fields = append(fields, fmt.Sprintf("%s %s PRIMARY KEY", field.ColumnName, field.ColumnType)) } else if !isStructOrSliceOfStructs || field.ColumnType != "" { lalala := fmt.Sprintf("%s %s", field.ColumnName, field.ColumnType) if !field.Nullable { lalala += " NOT NULL" } lalala += fmt.Sprintf(" DEFAULT %v", defaultColumnValue(field.Type)) fields = append(fields, lalala) } } else { ft := field.Type for ft.Kind() == reflect.Pointer { ft = ft.Elem() } for i := range ft.NumField() { efield := field.Type.Field(i) ctype := columnType(efield.Type, false, false) if ctype != "" { def := fmt.Sprintf("%s %s NOT NULL DEFAULT %v", pascalToSnakeCase(efield.Name), ctype, defaultColumnValue(efield.Type)) fields = append(fields, def) } } } } inter := strings.Join(fields, ", ") return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s);", m.TableName, inter) } func (m *Model) createJoinTableSql(relName string) string { ref, ok := m.Relationships[relName] if !ok { return "" } aTable := m.TableName joinTableName := ref.ComputeJoinTable() fct := serialToRegular(ref.primaryID().ColumnType) rct := serialToRegular(ref.relatedID().ColumnType) pkSection := fmt.Sprintf(",\nPRIMARY KEY (%s_id, %s_id)", aTable, ref.RelatedModel.TableName, ) if ref.m2mIsh() { pkSection = "" } return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( %s_id %s REFERENCES %s(%s) ON DELETE CASCADE, %s_id %s REFERENCES %s(%s) ON DELETE CASCADE %s );`, joinTableName, ref.Model.TableName, fct, ref.Model.TableName, ref.Model.Fields[ref.Model.IDField].ColumnName, ref.RelatedModel.TableName, rct, ref.RelatedModel.TableName, ref.RelatedModel.Fields[ref.RelatedModel.IDField].ColumnName, pkSection, ) } func (m *Model) generateConstraints(engine *Engine) error { for _, rel := range m.Relationships { field := rel.relatedID() if rel.Type != ManyToMany && rel.Type != HasMany && !rel.m2mIsh() { colType := serialToRegular(field.ColumnType) if !field.Nullable && !rel.Nullable { colType += " NOT NULL" } /*constraint := fmt.Sprintf("%s %s REFERENCES %s(%s)", pascalToSnakeCase(rel.joinField()), colType, rel.RelatedModel.TableName, field.ColumnName) if rel.Type != ManyToOne && rel.Type != BelongsTo { constraint += " ON DELETE CASCADE ON UPDATE CASCADE" }*/ fk := fmt.Sprintf("fk_%s", pascalToSnakeCase(capitalizeFirst(rel.Model.Name)+rel.FieldName+rel.relatedID().Name)) q := fmt.Sprintf(`ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s %s, ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE ON UPDATE CASCADE;`, rel.Model.TableName, pascalToSnakeCase(rel.joinField()), colType, fk, pascalToSnakeCase(rel.joinField()), rel.RelatedModel.TableName, field.ColumnName, ) dq := fmt.Sprintf(`ALTER TABLE %s DROP CONSTRAINT IF EXISTS %s;`, m.TableName, fk) fmt.Printf("%s\n%s\n", dq, q) if _, err := engine.conn.Exec(engine.ctx, dq); err != nil { return err } if _, err := engine.conn.Exec(engine.ctx, q); err != nil { return err } } } return nil } func (m *Model) migrate(engine *Engine) error { sql := m.createTableSql() fmt.Println(sql) if !engine.dryRun { _, err := engine.conn.Exec(engine.ctx, sql) if err != nil { return err } } for relName, rel := range m.Relationships { relkey := rel.ComputeJoinTable() if (rel.Type == ManyToMany && !engine.m2mSeen[relkey]) || (rel.Model.embeddedIsh && !rel.RelatedModel.embeddedIsh && rel.Type == HasMany) { if rel.Type == ManyToMany { engine.m2mSeen[relkey] = true engine.m2mSeen[rel.Model.Name] = true engine.m2mSeen[rel.RelatedModel.Name] = true } jtsql := m.createJoinTableSql(relName) fmt.Println(jtsql) if !engine.dryRun { _, err := engine.conn.Exec(engine.ctx, jtsql) if err != nil { return err } } } } return m.generateConstraints(engine) }