001/*
002 * Copyright 2015-2022 Transmogrify LLC, 2022-2026 Revetware LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.pyranid;
018
019import org.jspecify.annotations.NonNull;
020import org.jspecify.annotations.Nullable;
021
022import javax.sql.DataSource;
023import java.sql.Connection;
024import java.sql.DatabaseMetaData;
025import java.sql.SQLException;
026import java.util.Locale;
027
028import static java.util.Objects.requireNonNull;
029
030/**
031 * Identifies different types of databases, which allows for special platform-specific handling.
032 *
033 * @author <a href="https://www.revetkn.com">Mark Allen</a>
034 * @since 1.0.0
035 */
036public enum DatabaseType {
037        /**
038         * A database which requires no special handling.
039         */
040        GENERIC,
041        /**
042         * A PostgreSQL database.
043         */
044        POSTGRESQL,
045        /**
046         * An Oracle database.
047         */
048        ORACLE,
049        /**
050         * A MySQL database.
051         *
052         * @since 4.3.0
053         */
054        MYSQL,
055        /**
056         * A MariaDB database.
057         *
058         * @since 4.3.0
059         */
060        MARIA_DB,
061        /**
062         * A SQLite database.
063         *
064         * @since 4.3.0
065         */
066        SQLITE,
067        /**
068         * A Microsoft SQL Server database.
069         *
070         * @since 4.3.0
071         */
072        SQL_SERVER;
073
074        /**
075         * Determines the type of database to which the given {@code dataSource} connects.
076         * <p>
077         * Note: this will establish a {@link Connection} to the database.
078         *
079         * @param dataSource the database connection factory
080         * @return the type of database
081         * @throws DatabaseException if an exception occurs while attempting to read database metadata
082         */
083        @NonNull
084        public static DatabaseType fromDataSource(@NonNull DataSource dataSource) {
085                requireNonNull(dataSource);
086                
087                try (Connection connection = dataSource.getConnection()) {
088                        return fromConnection(connection);
089                } catch (SQLException e) {
090                        throw new DatabaseException("Unable to connect to database to determine its type", e);
091                }
092        }
093
094        /**
095         * Determines the type of database represented by the given {@code connection}.
096         *
097         * @param connection an active database connection
098         * @return the type of database
099         * @throws DatabaseException if an exception occurs while attempting to read database metadata
100         */
101        @NonNull
102        public static DatabaseType fromConnection(@NonNull Connection connection) {
103                requireNonNull(connection);
104
105                try {
106                        DatabaseMetaData databaseMetaData = connection.getMetaData();
107                        String databaseProductName = databaseMetaData.getDatabaseProductName();
108                        String databaseProductVersion = databaseProductVersion(databaseMetaData);
109                        String url = databaseMetaData.getURL();
110                        String driverName = databaseMetaData.getDriverName();
111
112                        // All of our checks are against databases with English names
113                        String databaseProductNameLowercase = databaseProductName == null ? "" : databaseProductName.toLowerCase(Locale.ENGLISH);
114                        String databaseProductVersionLowercase = databaseProductVersion == null ? "" : databaseProductVersion.toLowerCase(Locale.ENGLISH);
115                        String urlLowercase = url == null ? "" : url.toLowerCase(Locale.ENGLISH);
116                        String driverNameLowercase = driverName == null ? "" : driverName.toLowerCase(Locale.ENGLISH);
117
118                        // Prefer product name
119                        if (databaseProductNameLowercase.startsWith("oracle"))
120                                return DatabaseType.ORACLE;
121
122                        // Strict match for PostgreSQL
123                        if (databaseProductNameLowercase.contains("postgresql") || databaseProductNameLowercase.equals("postgres"))  // some proxies shorten it
124                                return DatabaseType.POSTGRESQL;
125
126                        if (databaseProductNameLowercase.contains("mariadb"))
127                                return DatabaseType.MARIA_DB;
128
129                        if (databaseProductNameLowercase.contains("mysql"))
130                                return mysqlFamilyDatabaseType(databaseProductVersionLowercase, driverNameLowercase);
131
132                        if (databaseProductNameLowercase.contains("sqlite"))
133                                return DatabaseType.SQLITE;
134
135                        if (isSqlServerProductName(databaseProductNameLowercase))
136                                return DatabaseType.SQL_SERVER;
137
138                        // Fallbacks if product name is absent/weird but we're clearly using a vendor driver/URL
139                        if (urlLowercase.startsWith("jdbc:postgresql:") || driverNameLowercase.contains("postgresql"))
140                                return DatabaseType.POSTGRESQL;
141
142                        if (urlLowercase.startsWith("jdbc:oracle:") || driverNameLowercase.contains("oracle jdbc"))
143                                return DatabaseType.ORACLE;
144
145                        if (urlLowercase.startsWith("jdbc:mariadb:") || driverNameLowercase.contains("mariadb"))
146                                return DatabaseType.MARIA_DB;
147
148                        if (isMysqlUrl(urlLowercase) || driverNameLowercase.contains("mysql"))
149                                return mysqlFamilyDatabaseType(databaseProductVersionLowercase, driverNameLowercase);
150
151                        if (urlLowercase.startsWith("jdbc:sqlite:") || driverNameLowercase.contains("sqlite"))
152                                return DatabaseType.SQLITE;
153
154                        if (isSqlServerUrl(urlLowercase) || isSqlServerDriverName(driverNameLowercase))
155                                return DatabaseType.SQL_SERVER;
156
157                        return DatabaseType.GENERIC;
158                } catch (SQLException e) {
159                        throw new DatabaseException("Unable to inspect database metadata to determine its type", e);
160                }
161        }
162
163        @Nullable
164        private static String databaseProductVersion(@NonNull DatabaseMetaData databaseMetaData) {
165                requireNonNull(databaseMetaData);
166
167                try {
168                        return databaseMetaData.getDatabaseProductVersion();
169                } catch (SQLException e) {
170                        return null;
171                }
172        }
173
174        @NonNull
175        DatabaseDialect dialect() {
176                return switch (this) {
177                        case POSTGRESQL -> PostgresDialect.INSTANCE;
178                        case ORACLE -> OracleDialect.INSTANCE;
179                        case MYSQL -> MySqlDialect.INSTANCE;
180                        case MARIA_DB -> MariaDbDialect.INSTANCE;
181                        case SQLITE -> SqliteDialect.INSTANCE;
182                        case SQL_SERVER -> SqlServerDialect.INSTANCE;
183                        case GENERIC -> GenericDialect.INSTANCE;
184                };
185        }
186
187        @NonNull
188        private static DatabaseType mysqlFamilyDatabaseType(@NonNull String databaseProductVersionLowercase,
189                                                                                                                                                                                                                 @NonNull String driverNameLowercase) {
190                requireNonNull(databaseProductVersionLowercase);
191                requireNonNull(driverNameLowercase);
192
193                if (databaseProductVersionLowercase.contains("mariadb") || driverNameLowercase.contains("mariadb"))
194                        return DatabaseType.MARIA_DB;
195
196                return DatabaseType.MYSQL;
197        }
198
199        private static boolean isMysqlUrl(@NonNull String urlLowercase) {
200                requireNonNull(urlLowercase);
201
202                return urlLowercase.startsWith("jdbc:mysql:") || urlLowercase.startsWith("jdbc:mysql+srv:");
203        }
204
205        private static boolean isSqlServerProductName(@NonNull String databaseProductNameLowercase) {
206                requireNonNull(databaseProductNameLowercase);
207
208                return databaseProductNameLowercase.contains("microsoft sql server")
209                                || databaseProductNameLowercase.equals("sql server")
210                                || databaseProductNameLowercase.equals("sqlserver");
211        }
212
213        private static boolean isSqlServerUrl(@NonNull String urlLowercase) {
214                requireNonNull(urlLowercase);
215
216                return urlLowercase.startsWith("jdbc:sqlserver:") || urlLowercase.startsWith("jdbc:jtds:sqlserver:");
217        }
218
219        private static boolean isSqlServerDriverName(@NonNull String driverNameLowercase) {
220                requireNonNull(driverNameLowercase);
221
222                return driverNameLowercase.contains("microsoft jdbc driver")
223                                || driverNameLowercase.contains("sql server")
224                                || driverNameLowercase.contains("jtds");
225        }
226}